author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-06-26 12:54:12 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-06-26 12:57:46 -0700
commit    f3c89936e97c99dead1ca3310246691c1b221adf (patch)
tree      3c99b66936ed59028b32609115a239f52798907d /tensorflow/compiler
parent    0b9b09a8531004b44b133a52c3fcc67bc6759bd8 (diff)
Merge changes from github.
END_PUBLIC
Note: this CL will break builds. cl/159887762 to follow to fix all the breakages.
--- Commit 2336cdf7f authored by Maxwell Paul Brickner<mbrickn@users.noreply.github.com> Committed by gunan<gunan@google.com>: Updated link to use HTTPS (#10998) Howdy! I just updated a link to use https instead of http. Thanks!
--- Commit ad0892df1 authored by Luke Iwanski<luke@codeplay.com> Committed by Luke Iwanski<luke@codeplay.com>: [OpenCL] Fixes run_metadata_test for SYCL This test is designed to test CUDA specific behavior
--- Commit 6b37a0725 authored by Todd Wang<toddwang@gmail.com> Committed by GitHub<noreply@github.com>: Update comments
--- Commit 1699d904a authored by John Lawson<john@codeplay.com> Committed by Luke Iwanski<luke@codeplay.com>: [OpenCL] Fixes CUDA specific test run on SYCL (#56) The testBadParentValuesOnGPU should only be run on CUDA devices, as the test checks for particular CUDA behaviour. We don't actually provide a SYCL kernel for GatherTree and so it's not a problem that the tests don't target SYCL.
--- Commit 3c1946230 authored by myPrecious<Moriadry@users.noreply.github.com> Committed by Shanqing Cai<cais@google.com>: Java API to get the size of specified input list of operations. (#10865) * Java API to get the size of specified input list of operations * remove unnecessary explain to avoid bring a new term to users.
--- Commit e911c7480 authored by Luke Iwanski<luke@codeplay.com> Committed by Luke Iwanski<luke@codeplay.com>: [OpenCL] REGISTER -> REGISTER6
--- Commit fbf6c4cec authored by superryanguo<superryanguo@gmail.com> Committed by superryanguo<superryanguo@gmail.com>: Simplify the Quickstart section with the weblink is better
--- Commit 72e2918cc authored by Taehoon Lee<taehoonlee@snu.ac.kr> Committed by Taehoon Lee<taehoonlee@snu.ac.kr>: Fix typos
--- Commit 90c4406b7 authored by Rishabh Patel<patelrishabh@users.noreply.github.com> Committed by GitHub<noreply@github.com>: Correct the learning rate as per the code snippet
--- Commit 03da61134 authored by Todd Wang<toddwang@gmail.com> Committed by GitHub<noreply@github.com>: Update ir_array.cc
--- Commit 2df6cd3ac authored by Todd Wang<toddwang@gmail.com> Committed by GitHub<noreply@github.com>: Another try
--- Commit af0cbace1 authored by Luke Iwanski<luke@codeplay.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: [OpenCL] Transpose to go through Eigen (#10321)
--- Commit fc7361081 authored by Luke Iwanski<luke@codeplay.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: [OpenCL] Registers RGBToHSV and HSVToRGB (#91) (#10848) * [OpenCL] Added RGBToHSV and HSVToRGB * Aligning '\'
--- Commit 832894ef8 authored by Luke Iwanski<luke@codeplay.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: [OpenCL] Registers AdjustContrastv2 (#10949) * [OpenCL] Registers AdjustContrastv2 (#93) * [OpenCL] Extended adjust_contrast_op_benchmark_test for OpenCL (#96) * [OpenCL] Extended adjust_contrast_op_benchmark_test for OpenCL * simplified to #ifndef * Changed to "#if GOOGLE_CUDA" * Update adjust_contrast_op_benchmark_test.cc * Added comments
--- Commit cb4c2f8d1 authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Make TransferBufferToInFeed not virual so it compiles.
--- Commit e89f04d80 authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Fix calling Literal member functions.
--- Commit 15a8df724 authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Fix mac build clone from meheff's change: [XLA] Change return type of DeviceAssignment::Deserialize to fix build breakage on mac. The mac build had the following error: error: incomplete type 'xla::DeviceAssignment' used in type trait expression This was due to a static method returning a StatusOr<DeviceAssignment> inside of the definition of DeviceAssignment.
--- Commit a54d43fa4 authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Replace LiteralUtil to Literal in compiler/plugin/executor
--- Commit 88a6bb80c authored by Guenther Schmuelling<guschmue@microsoft.com> Committed by Guenther Schmuelling<guschmue@microsoft.com>: expand inline for debug builds to limit number of symbols
--- Commit 62fb49d31 authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Fix visibility error for contrib/remote_fused_graph/pylib/BUILD.
--- Commit 4c75252f2 authored by Mark Neumann<markn@allenai.org> Committed by Mark Neumann<markn@allenai.org>: fix initial test values to avoid numerical instability
--- Commit b58d98353 authored by sj6077<epik03sj@gmail.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: Fixes of AutoParallel bug (#10368) * Fix the bug that auto_parallel could replicate variable snapshot name * Use NodeName in grappler:utils instead of substr, convert variables->variable_def of grappler item * remove variable_def from grappler item, exclude snapshot nodes from dont_replicate_nodes in auto_parallel
--- Commit a286b7db8 authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Make debug_test slice integer.
--- Commit 97fcfdfa6 authored by Toby Boyd<tobyboyd@google.com> Committed by GitHub<noreply@github.com>: Fixed path to seq2seq.py and minor formatting
--- Commit 63c1befb8 authored by Anish Shah<shah.anish07@gmail.com> Committed by Anish Shah<shah.anish07@gmail.com>: Improve docs for tf.nn.depthwise_conv2d_native
--- Commit 8d42202b2 authored by Yong Tang<yong.tang.github@outlook.com> Committed by Yong Tang<yong.tang.github@outlook.com>: Fix mismatched delete in mkl_tfconv_op.cc This fix fixes mismatched new[]-delete in mkl_tfconv_op.cc (the file went through clang-format so there are some additional changes) Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
--- Commit 26301bd55 authored by Danny Goodman<goodman.danny@gmail.com> Committed by Danny Goodman<goodman.danny@gmail.com>: fix error format
--- Commit b3f33ad46 authored by Yao Zhang<yaozhang@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Make changes to prepare for the fused option of batch norm to be set to None (None means using fused batch norm if possible). PiperOrigin-RevId: 159649743
--- Commit a4a469832 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [XLA] Add tests for select ops and while loops that produce tuples that contain predicates. PiperOrigin-RevId: 159645900
--- Commit 980d3f2be authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Use C API to implement Operation.name property This name property is used in many existing tests including those that already run with C API enabled (math_ops_test, framework_ops_test, session_test, session_partial_run_test, math_ops_test_gpu, etc). PiperOrigin-RevId: 159645767
--- Commit 26239c706 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Previously we didn't have an implementation of BatchNormInference and BatchNormTraining, which gives a linker error if anyone ever tries to call that. A dummy implementation is friendlier than a linker error. PiperOrigin-RevId: 159645612
--- Commit f671c5caa authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: BEGIN_PUBLIC Automated g4 rollback of changelist 159570549
PiperOrigin-RevId: 160182040
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--  tensorflow/compiler/plugin/BUILD | 4
-rw-r--r--  tensorflow/compiler/plugin/executor/BUILD | 32
-rw-r--r--  tensorflow/compiler/plugin/executor/compiler.cc | 123
-rw-r--r--  tensorflow/compiler/plugin/executor/compiler.h | 64
-rw-r--r--  tensorflow/compiler/plugin/executor/device.cc | 60
-rw-r--r--  tensorflow/compiler/plugin/executor/executable.cc | 147
-rw-r--r--  tensorflow/compiler/plugin/executor/executable.h | 65
-rw-r--r--  tensorflow/compiler/plugin/executor/executor.cc | 135
-rw-r--r--  tensorflow/compiler/plugin/executor/executor.h | 213
-rw-r--r--  tensorflow/compiler/plugin/executor/platform.cc | 125
-rw-r--r--  tensorflow/compiler/plugin/executor/platform.h | 83
-rw-r--r--  tensorflow/compiler/plugin/executor/platform_id.h | 31
-rw-r--r--  tensorflow/compiler/plugin/executor/transfer_manager.cc | 187
-rw-r--r--  tensorflow/compiler/plugin/executor/transfer_manager.h | 77
-rw-r--r--  tensorflow/compiler/tests/ftrl_test.py | 2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc | 6
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc | 3
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc | 19
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/diag_op.cc | 8
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc | 4
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/slice_op.cc | 4
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/split_op.cc | 14
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc | 44
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 8
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/unpack_op.cc | 4
-rw-r--r--  tensorflow/compiler/xla/client/computation_builder.cc | 6
-rw-r--r--  tensorflow/compiler/xla/client/computation_builder.h | 4
-rw-r--r--  tensorflow/compiler/xla/literal_util.cc | 10
-rw-r--r--  tensorflow/compiler/xla/literal_util_test.cc | 57
-rw-r--r--  tensorflow/compiler/xla/service/BUILD | 2
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment_test.cc | 12
-rw-r--r--  tensorflow/compiler/xla/service/buffer_liveness_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/compile_only_service.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/computation_placer.cc | 9
-rw-r--r--  tensorflow/compiler/xla/service/computation_placer.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 9
-rw-r--r--  tensorflow/compiler/xla/service/elemental_ir_emitter.cc | 17
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_insertion.cc | 11
-rw-r--r--  tensorflow/compiler/xla/service/hlo_constant_folding_test.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h | 13
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization_test.cc | 16
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc | 23
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc | 28
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.h | 3
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference_test.cc | 33
-rw-r--r--  tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/user_computation.cc | 6
-rw-r--r--  tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/tests/dot_operation_test.cc | 4
-rw-r--r--  tensorflow/compiler/xla/tests/fusion_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/tests/multidimensional_slice_test.cc | 4
-rw-r--r--  tensorflow/compiler/xla/tests/params_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/tests/slice_test.cc | 51
-rw-r--r--  tensorflow/compiler/xla/tests/while_test.cc | 3
-rw-r--r--  tensorflow/compiler/xla/util.h | 18
-rw-r--r--  tensorflow/compiler/xla/xla_data.proto | 3
59 files changed, 1666 insertions, 181 deletions
diff --git a/tensorflow/compiler/plugin/BUILD b/tensorflow/compiler/plugin/BUILD
index 4badd3a589..8c2e9a7c81 100644
--- a/tensorflow/compiler/plugin/BUILD
+++ b/tensorflow/compiler/plugin/BUILD
@@ -32,5 +32,7 @@ package(
cc_library(
name = "plugin",
- deps = [],
+ deps = [
+ "//tensorflow/compiler/plugin/executor:plugin_lib",
+ ],
)
diff --git a/tensorflow/compiler/plugin/executor/BUILD b/tensorflow/compiler/plugin/executor/BUILD
new file mode 100644
index 0000000000..9bc706abdf
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/BUILD
@@ -0,0 +1,32 @@
+licenses(["restricted"])
+
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+ name = "plugin_lib",
+ srcs = glob([
+ "*.cc",
+ ]),
+ hdrs = glob([
+ "*.h",
+ ]),
+ deps = [
+ "//tensorflow/compiler/jit:xla_jit_headers_lib",
+ "//tensorflow/compiler/xla:xla_headers_lib",
+ "//tensorflow/compiler/xla/service:hlo_evaluator",
+ "//third_party/eigen3",
+ "@local_config_cuda//cuda:cuda_headers",
+ "@protobuf//:protobuf_headers",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+)
diff --git a/tensorflow/compiler/plugin/executor/compiler.cc b/tensorflow/compiler/plugin/executor/compiler.cc
new file mode 100644
index 0000000000..893ff152f0
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/compiler.cc
@@ -0,0 +1,123 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdlib.h>
+#include <fstream>
+
+#include "tensorflow/compiler/plugin/executor/compiler.h"
+#include "tensorflow/compiler/plugin/executor/executable.h"
+
+#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
+#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
+#include "tensorflow/compiler/xla/service/hlo_cse.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
+#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
+#include "tensorflow/compiler/xla/service/inliner.h"
+#include "tensorflow/compiler/xla/service/reshape_mover.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+
+#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/strcat.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace se = ::perftools::gputools;
+namespace sep = ::perftools::gputools::executorplugin;
+namespace port = ::perftools::gputools::port;
+
+namespace xla {
+namespace executorplugin {
+
+/*
+ * Run optimization passes on the module. The graph is transformed by
+ * each pass in the optimization pipeline. The service subdirectory
+ * contains useful optimization passes.
+ */
+Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module,
+ HloDumper dump_hlo) {
+ HloPassPipeline pipeline("Executor", dump_hlo);
+ pipeline.AddPass<Inliner>();
+ pipeline.AddPass<HloSubcomputationUnification>();
+ pipeline.AddPass<HloCSE>(false);
+
+ pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
+ false, [](const Shape&, const Shape&) { return false; });
+ pipeline.AddPass<ReshapeMover>();
+ pipeline.AddPass<HloConstantFolding>();
+ pipeline.AddPass<HloCSE>(true);
+
+ pipeline.AddPass<HloDCE>();
+ pipeline.AddPass<FlattenCallGraph>();
+ return pipeline.Run(hlo_module).status();
+}
+
+StatusOr<std::unique_ptr<Executable>> ExecutorCompiler::Compile(
+ std::unique_ptr<HloModule> hlo_module, HloDumper dump_hlo,
+ se::StreamExecutor* stream_exec) {
+ TF_RET_CHECK(stream_exec != nullptr);
+
+ VLOG(1) << "Generate graph " << hlo_module->name();
+
+ TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get(), dump_hlo));
+
+ // Typically you would visit the HLO graph, building up a compiled
+ // equivalent. In this case we are using an HLO evaluator at execution
+ // time, so we don't need to compile anything.
+
+ // Create the executable from only the HLO module.
+ std::unique_ptr<Executable> executable;
+ executable.reset(new ExecutorExecutable(std::move(hlo_module)));
+
+ return std::move(executable);
+}
+
+StatusOr<std::vector<std::unique_ptr<Executable>>> ExecutorCompiler::Compile(
+ std::vector<std::unique_ptr<HloModule>> hlo_modules,
+ HloDumper dump_hlos, std::vector<se::StreamExecutor*> stream_execs) {
+
+ return tensorflow::errors::Unimplemented(
+ "Compilation of multiple HLO modules is not supported on Executor.");
+}
+
+StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+ExecutorCompiler::CompileAheadOfTime(
+ std::vector<std::unique_ptr<HloModule>> hlo_modules,
+ HloDumper dump_hlo, const AotCompilationOptions& aot_options) {
+
+ return tensorflow::errors::InvalidArgument(
+ "AOT compilation not supported on Executor");
+}
+
+se::Platform::Id ExecutorCompiler::PlatformId() const {
+ return sep::kExecutorPlatformId;
+}
+
+HloCostAnalysis::ShapeSizeFunction
+ExecutorCompiler::ShapeSizeBytesFunction() const {
+ return ExecutorExecutable::ShapeSizeBytes;
+}
+
+
+} // namespace executorplugin
+} // namespace xla
+
+REGISTER_MODULE_INITIALIZER(executor_compiler, {
+ xla::Compiler::RegisterCompilerFactory(sep::kExecutorPlatformId, []() {
+ return xla::MakeUnique<xla::executorplugin::ExecutorCompiler>();
+ });
+});
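For context only, a minimal sketch (not part of this change) of how the service side would typically obtain the compiler registered by the module initializer above. It assumes the Compiler::GetForPlatform lookup declared in tensorflow/compiler/xla/service/compiler.h and that the caller already holds the plugin's se::Platform pointer; the function name is illustrative.

    // Sketch: resolve the compiler registered for the executor platform id.
    #include "tensorflow/compiler/xla/service/compiler.h"

    xla::StatusOr<xla::Compiler*> CompilerFor(
        const perftools::gputools::Platform* platform) {
      // For this plugin, platform->id() is sep::kExecutorPlatformId, so the
      // factory registered above is the one consulted by the registry.
      return xla::Compiler::GetForPlatform(platform);
    }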
diff --git a/tensorflow/compiler/plugin/executor/compiler.h b/tensorflow/compiler/plugin/executor/compiler.h
new file mode 100644
index 0000000000..8fe591c8ab
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/compiler.h
@@ -0,0 +1,64 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_EXECUTOR_COMPILER_H_
+#define TENSORFLOW_COMPILER_EXECUTOR_COMPILER_H_
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+
+#include "tensorflow/compiler/plugin/executor/platform_id.h"
+
+namespace xla {
+namespace executorplugin {
+
+class ExecutorCompiler : public Compiler {
+ public:
+ ExecutorCompiler() {}
+ ~ExecutorCompiler() override {}
+
+ StatusOr<std::unique_ptr<Executable>> Compile(
+ std::unique_ptr<HloModule> hlo_module,
+ HloDumper dump_hlo,
+ perftools::gputools::StreamExecutor* stream_exec) override;
+
+ StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
+ std::vector<std::unique_ptr<HloModule>> hlo_module,
+ HloDumper dump_hlo,
+ std::vector<perftools::gputools::StreamExecutor*> stream_exec) override;
+
+ StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+ CompileAheadOfTime(
+ std::vector<std::unique_ptr<HloModule>> module,
+ HloDumper dump_hlo, const AotCompilationOptions& options) override;
+
+ HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override;
+
+ perftools::gputools::Platform::Id PlatformId() const override;
+
+ private:
+ Status RunHloOptimization(HloModule* hlo_module, HloDumper dump_hlo);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExecutorCompiler);
+};
+
+} // namespace executorplugin
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_EXECUTOR_COMPILER_H_
diff --git a/tensorflow/compiler/plugin/executor/device.cc b/tensorflow/compiler/plugin/executor/device.cc
new file mode 100644
index 0000000000..bbc39dc03f
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/device.cc
@@ -0,0 +1,60 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/kernels/xla_device_launch_op.h"
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/compiler/jit/xla_device_ops.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+
+const char* const DEVICE_XLA_EXEC = "XLA_EXEC";
+const char* const DEVICE_EXEC_XLA_JIT = "XLA_EXEC_JIT";
+
+constexpr std::array<DataType, 5> kExecAllTypes = {
+ {DT_INT32, DT_FLOAT, DT_BOOL, DT_DOUBLE, DT_INT64}};
+
+class XlaExaDeviceFactory : public DeviceFactory {
+ public:
+ Status CreateDevices(const SessionOptions& options, const string& name_prefix,
+ std::vector<Device*>* devices) override;
+};
+
+Status XlaExaDeviceFactory::CreateDevices(const SessionOptions& options,
+ const string& name_prefix,
+ std::vector<Device*>* devices) {
+ static XlaDeviceOpRegistrations* registrations =
+ RegisterXlaDeviceKernels(DEVICE_XLA_EXEC, DEVICE_EXEC_XLA_JIT);
+ (void)registrations;
+
+ std::unique_ptr<XlaDevice> device;
+ TF_RETURN_IF_ERROR(XlaDevice::Create("Executor", DEVICE_XLA_EXEC, 0,
+ DEVICE_EXEC_XLA_JIT, options,
+ name_prefix, &device));
+ devices->push_back(device.release());
+ return Status::OK();
+}
+
+REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_EXEC, XlaExaDeviceFactory, 110);
+
+// Kernel registrations
+
+static bool OpFilter(KernelDef* kdef) { return true; }
+
+REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_EXEC, XlaDeviceLaunchOp, kExecAllTypes);
+REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_EXEC, kExecAllTypes);
+REGISTER_XLA_BACKEND(DEVICE_EXEC_XLA_JIT, kExecAllTypes, OpFilter);
+
+} // namespace tensorflow
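For reference, a hedged sketch of how the XLA_EXEC device registered above might be exercised from the C++ client API once the plugin is linked into the binary. The device string "/device:XLA_EXEC:0" follows the convention used by the other XLA devices and is an assumption here, as is the use of the //tensorflow/cc ClientSession API; this is illustrative, not part of the change.

    // Sketch: place a small graph on the plugin's XLA_EXEC device.
    #include <vector>

    #include "tensorflow/cc/client/client_session.h"
    #include "tensorflow/cc/ops/standard_ops.h"
    #include "tensorflow/core/framework/tensor.h"

    void RunOnExecutorDevice() {
      using namespace tensorflow;       // NOLINT
      using namespace tensorflow::ops;  // NOLINT
      // Assumed device string; other XLA devices use "/device:XLA_CPU:0" etc.
      Scope root = Scope::NewRootScope().WithDevice("/device:XLA_EXEC:0");
      auto a = Const(root, {{1.f, 2.f}, {3.f, 4.f}});
      auto b = Const(root, {{1.f, 0.f}, {0.f, 1.f}});
      auto product = MatMul(root, a, b);
      ClientSession session(root);
      std::vector<Tensor> outputs;
      TF_CHECK_OK(session.Run({product}, &outputs));
      LOG(INFO) << outputs[0].matrix<float>();
    }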
diff --git a/tensorflow/compiler/plugin/executor/executable.cc b/tensorflow/compiler/plugin/executor/executable.cc
new file mode 100644
index 0000000000..79eea9af3f
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/executable.cc
@@ -0,0 +1,147 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/plugin/executor/executable.h"
+#include "tensorflow/compiler/plugin/executor/executor.h"
+
+#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+
+namespace se = ::perftools::gputools;
+namespace sep = ::perftools::gputools::executorplugin;
+
+namespace xla {
+namespace executorplugin {
+
+ExecutorExecutable::ExecutorExecutable(std::unique_ptr<HloModule> hlo_module)
+ : Executable(std::move(hlo_module), ShapeSizeBytes) {}
+
+ExecutorExecutable::~ExecutorExecutable() {}
+
+static se::DeviceMemoryBase AllocateSingleOutput(sep::ExecutorExecutor* executor,
+ const Literal& literal) {
+ int64 size(xla::ShapeUtil::ByteSizeOf(literal.shape()));
+ void* buf = executor->Allocate(size);
+ const void* src = literal.InternalData();
+ memcpy(buf, src, size);
+ return se::DeviceMemoryBase(buf, size);
+}
+
+static se::DeviceMemoryBase AllocateOutputBuffer(sep::ExecutorExecutor* executor,
+ const Literal& literal) {
+ const Shape& shape = literal.shape();
+ if (shape.element_type() != xla::TUPLE) {
+ return AllocateSingleOutput(executor, literal);
+ } else {
+ int64 size(xla::ShapeUtil::ByteSizeOf(shape, sizeof(void*)));
+ void** buf = reinterpret_cast<void**>(executor->Allocate(size));
+ for (int64 n = 0; n < xla::ShapeUtil::TupleElementCount(shape); n++) {
+ se::DeviceMemoryBase out =
+ AllocateSingleOutput(executor, literal.tuple_literals(n));
+ *buf++ = out.opaque();
+ }
+
+ return se::DeviceMemoryBase(buf, size);
+ }
+}
+
+StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteOnStream(
+ const ServiceExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
+ HloExecutionProfile* hlo_execution_profile) {
+ se::Stream* stream = run_options->stream();
+
+ VLOG(1) << "Execute " << module().name();
+ if (VLOG_IS_ON(2)) {
+ for (const auto& a : arguments) {
+ VLOG(2) << "-- argument " << a.opaque();
+ }
+ }
+
+ uint64 start_micros = tensorflow::Env::Default()->NowMicros();
+
+ HloComputation* computation = module().entry_computation();
+ if (computation->num_parameters() != arguments.size()) {
+ return tensorflow::errors::Internal(
+ "Mismatch between argument count and graph parameter count.");
+ }
+
+ // Create the arguments as a vector of XLA literals.
+ std::vector<std::unique_ptr<Literal>> arg_literals;
+ std::vector<Literal*> arg_literals_ptrs;
+ for (int64 p = 0; p < computation->num_parameters(); p++) {
+ // Create the input literal for the parameter
+ HloInstruction* param = computation->parameter_instruction(p);
+ arg_literals.emplace_back(Literal::CreateFromShape(param->shape()));
+ arg_literals_ptrs.push_back(arg_literals.back().get());
+
+ // Copy in the data from the stream_executor buffers
+ void* buffer = arg_literals.back().get()->MutableInternalData();
+ memcpy(buffer, arguments[p].opaque(),
+ ShapeUtil::ByteSizeOf(param->shape()));
+ }
+
+ // Execute the graph using the evaluator
+ HloEvaluator evaluator;
+ std::unique_ptr<Literal> output;
+ TF_ASSIGN_OR_RETURN(output,
+ evaluator.Evaluate(computation, arg_literals_ptrs));
+
+ // Copy the result into the return buffer
+ perftools::gputools::StreamExecutor* executor(stream->parent());
+ sep::ExecutorExecutor* executorExecutor(
+ static_cast<sep::ExecutorExecutor*>(executor->implementation()));
+
+ se::DeviceMemoryBase ret =
+ AllocateOutputBuffer(executorExecutor, *(output.get()));
+
+ uint64 end_micros = tensorflow::Env::Default()->NowMicros();
+
+ {
+ tensorflow::mutex_lock lock(mutex_);
+ const double nanoseconds = (end_micros - start_micros) * 1000.0;
+ execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0));
+ }
+
+ return ret;
+}
+
+StatusOr<std::unique_ptr<ShapedBuffer>> ExecutorExecutable::ExecuteOnStream(
+ const ServiceExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ HloExecutionProfile* hlo_execution_profile) {
+ return tensorflow::errors::Unimplemented(
+ "ExecuteOnStream is not yet supported on Executor.");
+}
+
+StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteAsyncOnStream(
+ const ServiceExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
+ return tensorflow::errors::Unimplemented(
+ "ExecuteAsyncOnStream is not yet supported on Executor.");
+}
+
+/*static*/ int64 ExecutorExecutable::ShapeSizeBytes(const Shape& shape) {
+ if (ShapeUtil::IsOpaque(shape)) {
+ return sizeof(void*);
+ }
+ return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
+}
+
+
+} // namespace executorplugin
+} // namespace xla
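To make the size convention above concrete, a small illustrative sketch of what ShapeSizeBytes yields for a dense array versus a tuple, assuming ShapeUtil's documented behaviour that sizing a tuple with an explicit pointer size covers only the top-level buffer of element pointers.

    // Sketch: the pointer-size convention used by ExecutorExecutable.
    #include <cstdint>

    #include "tensorflow/compiler/xla/shape_util.h"
    #include "tensorflow/compiler/xla/xla_data.pb.h"

    void ShapeSizeExamples() {
      // Dense f32[2,3]: 2 * 3 elements * 4 bytes each = 24 bytes.
      xla::Shape matrix = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
      int64_t matrix_bytes = xla::ShapeUtil::ByteSizeOf(matrix, sizeof(void*));

      // Tuple of two such arrays: the top-level buffer holds one pointer per
      // element, i.e. 2 * sizeof(void*) = 16 bytes on a 64-bit host.
      xla::Shape pair = xla::ShapeUtil::MakeTupleShape({matrix, matrix});
      int64_t pair_bytes = xla::ShapeUtil::ByteSizeOf(pair, sizeof(void*));

      (void)matrix_bytes;
      (void)pair_bytes;
    }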
diff --git a/tensorflow/compiler/plugin/executor/executable.h b/tensorflow/compiler/plugin/executor/executable.h
new file mode 100644
index 0000000000..ba3d4da21d
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/executable.h
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_EXECUTABLE_H_
+#define TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_EXECUTABLE_H_
+
+#include <cstddef>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+namespace xla {
+namespace executorplugin {
+
+class ExecutorExecutable : public Executable {
+ public:
+ ExecutorExecutable(std::unique_ptr<HloModule> hlo_module);
+ ~ExecutorExecutable() override;
+
+ StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream(
+ const ServiceExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ HloExecutionProfile* hlo_execution_profile) override;
+
+ StatusOr<std::unique_ptr<ShapedBuffer>> ExecuteOnStream(
+ const ServiceExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ HloExecutionProfile* hlo_execution_profile) override;
+
+ StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteAsyncOnStream(
+ const ServiceExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments) override;
+
+ static int64 ShapeSizeBytes(const Shape& shape);
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(ExecutorExecutable);
+};
+
+} // namespace executorplugin
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_EXECUTABLE_H_
diff --git a/tensorflow/compiler/plugin/executor/executor.cc b/tensorflow/compiler/plugin/executor/executor.cc
new file mode 100644
index 0000000000..e72c2711f7
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/executor.cc
@@ -0,0 +1,135 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/plugin/executor/executor.h"
+#include "tensorflow/compiler/plugin/executor/platform_id.h"
+
+#include "tensorflow/compiler/xla/status_macros.h"
+
+#include <stdlib.h>
+#include <string.h>
+
+namespace se = ::perftools::gputools;
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+host::HostStream *AsExecutorStream(Stream *stream) {
+ DCHECK(stream != nullptr);
+ return dynamic_cast<host::HostStream *>(stream->implementation());
+}
+
+ExecutorExecutor::ExecutorExecutor(const PluginConfig &plugin_config)
+ : plugin_config_(plugin_config) {}
+
+ExecutorExecutor::~ExecutorExecutor() {}
+
+void *ExecutorExecutor::Allocate(uint64 size) {
+ void *buf = new char[size];
+ return buf;
+}
+
+void *ExecutorExecutor::AllocateSubBuffer(DeviceMemoryBase *parent,
+ uint64 offset_bytes,
+ uint64 size_bytes) {
+ // A sub-buffer is an offset into the parent's underlying host allocation.
+ return static_cast<char *>(parent->opaque()) + offset_bytes;
+}
+
+void ExecutorExecutor::Deallocate(DeviceMemoryBase *mem) {
+ if (!mem->is_sub_buffer()) {
+ delete[] static_cast<char *>(mem->opaque());
+ }
+}
+
+bool ExecutorExecutor::Memcpy(Stream *stream, void *host_dst,
+ const DeviceMemoryBase &dev_src, uint64 size) {
+ AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() {
+ port::Status ok = SynchronousMemcpy(host_dst, dev_src, size);
+ });
+ return true;
+}
+
+bool ExecutorExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst,
+ const void *host_src, uint64 size) {
+ AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() {
+ port::Status ok = SynchronousMemcpy(dev_dst, host_src, size);
+ });
+ return true;
+}
+
+port::Status ExecutorExecutor::SynchronousMemcpy(DeviceMemoryBase *dev_dst,
+ const void *host_src,
+ uint64 size) {
+ memcpy(dev_dst->opaque(), host_src, size);
+ return port::Status::OK();
+}
+
+port::Status ExecutorExecutor::SynchronousMemcpy(void *host_dst,
+ const DeviceMemoryBase &dev_src,
+ uint64 size) {
+ memcpy(host_dst, dev_src.opaque(), size);
+ return port::Status::OK();
+}
+
+bool ExecutorExecutor::HostCallback(Stream *stream,
+ std::function<void()> callback) {
+ AsExecutorStream(stream)->EnqueueTask(callback);
+ return true;
+}
+
+bool ExecutorExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
+ AsExecutorStream(dependent)->EnqueueTask(
+ [other]() { other->BlockHostUntilDone(); });
+ AsExecutorStream(dependent)->BlockUntilDone();
+ return true;
+}
+
+bool ExecutorExecutor::StartTimer(Stream *stream, Timer *timer) {
+ dynamic_cast<host::HostTimer *>(timer->implementation())->Start(stream);
+ return true;
+}
+
+bool ExecutorExecutor::StopTimer(Stream *stream, Timer *timer) {
+ dynamic_cast<host::HostTimer *>(timer->implementation())->Stop(stream);
+ return true;
+}
+
+bool ExecutorExecutor::BlockHostUntilDone(Stream *stream) {
+ AsExecutorStream(stream)->BlockUntilDone();
+ return true;
+}
+
+DeviceDescription *ExecutorExecutor::PopulateDeviceDescription() const {
+ internal::DeviceDescriptionBuilder builder;
+
+ builder.set_device_address_bits(64);
+
+ builder.set_name("Executor");
+ builder.set_device_vendor("VectorName");
+ builder.set_platform_version("1.0");
+ builder.set_driver_version("1.0");
+ builder.set_runtime_version("1.0");
+ builder.set_pci_bus_id("1");
+ builder.set_device_memory_size(static_cast<uint64>(4) * 1024 * 1024 * 1024);
+ builder.set_clock_rate_ghz(static_cast<float>(CLOCKS_PER_SEC) / 1e9);
+
+ auto built = builder.Build();
+ return built.release();
+}
+
+} // namespace executorplugin
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/compiler/plugin/executor/executor.h b/tensorflow/compiler/plugin/executor/executor.h
new file mode 100644
index 0000000000..32fdb157e4
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/executor.h
@@ -0,0 +1,213 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Declares the ExecutorExecutor class, which is a CPU-only implementation of
+// the StreamExecutor interface. For now, this is used for testing and to
+// examine the performance of host-based StreamExecutor code.
+#ifndef TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_
+#define TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_
+
+#include "tensorflow/stream_executor/host/host_stream.h"
+#include "tensorflow/stream_executor/host/host_timer.h"
+
+#include "tensorflow/compiler/xla/shape_util.h"
+
+#include "tensorflow/stream_executor/blas.h"
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/rng.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+#include <list>
+#include <mutex>
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+using Args = tensorflow::gtl::ArraySlice<DeviceMemoryBase>;
+
+class ExecutorExecutor : public internal::StreamExecutorInterface {
+ public:
+ explicit ExecutorExecutor(const PluginConfig &plugin_config);
+ ~ExecutorExecutor() override;
+
+ port::Status Init(int device_ordinal, DeviceOptions device_options) override {
+ return port::Status::OK();
+ }
+
+ bool GetKernel(const MultiKernelLoaderSpec &spec,
+ KernelBase *kernel) override {
+ return false;
+ }
+ bool Launch(Stream *stream, const ThreadDim &thread_dims,
+ const BlockDim &block_dims, const KernelBase &kernel,
+ const KernelArgsArrayBase &args) override {
+ return false;
+ }
+
+ void *Allocate(uint64 size) override;
+ void *AllocateSubBuffer(DeviceMemoryBase *mem, uint64 offset_bytes,
+ uint64 size_bytes) override;
+ void Deallocate(DeviceMemoryBase *mem) override;
+
+ void *HostMemoryAllocate(uint64 size) override { return new char[size]; }
+ void HostMemoryDeallocate(void *mem) override {
+ delete[] static_cast<char *>(mem);
+ }
+ bool HostMemoryRegister(void *mem, uint64 size) override { return true; }
+ bool HostMemoryUnregister(void *mem) override { return true; }
+
+ bool Memcpy(Stream *stream, void *host_dst, const DeviceMemoryBase &pop_src,
+ uint64 size) override;
+ bool Memcpy(Stream *stream, DeviceMemoryBase *pop_dst, const void *host_src,
+ uint64 size) override;
+ bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *pop_dst,
+ const DeviceMemoryBase &host_src,
+ uint64 size) override {
+ return false;
+ }
+
+ bool MemZero(Stream *stream, DeviceMemoryBase *location,
+ uint64 size) override {
+ return false;
+ }
+ bool Memset(Stream *stream, DeviceMemoryBase *location, uint8 pattern,
+ uint64 size) override {
+ return false;
+ }
+ bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern,
+ uint64 size) override {
+ return false;
+ }
+
+ // No "synchronize all activity" implemented for this platform at the moment.
+ bool SynchronizeAllActivity() override { return false; }
+ bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override {
+ return false;
+ }
+
+ bool SynchronousMemSet(DeviceMemoryBase *location, int value,
+ uint64 size) override {
+ return false;
+ }
+
+ port::Status SynchronousMemcpy(DeviceMemoryBase *pop_dst,
+ const void *host_src, uint64 size) override;
+ port::Status SynchronousMemcpy(void *host_dst,
+ const DeviceMemoryBase &pop_src,
+ uint64 size) override;
+ port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *pop_dst,
+ const DeviceMemoryBase &pop_src,
+ uint64 size) override {
+ return port::Status{port::error::UNIMPLEMENTED, ""};
+ }
+
+ bool HostCallback(Stream *stream, std::function<void()> callback) override;
+
+ port::Status AllocateEvent(Event *event) override {
+ return port::Status{port::error::UNIMPLEMENTED, ""};
+ }
+
+ port::Status DeallocateEvent(Event *event) override {
+ return port::Status{port::error::UNIMPLEMENTED, ""};
+ }
+
+ port::Status RecordEvent(Stream *stream, Event *event) override {
+ return port::Status{port::error::UNIMPLEMENTED, ""};
+ }
+
+ port::Status WaitForEvent(Stream *stream, Event *event) override {
+ return port::Status{port::error::UNIMPLEMENTED, ""};
+ }
+
+ Event::Status PollForEventStatus(Event *event) override {
+ return Event::Status::kError;
+ }
+
+ bool AllocateStream(Stream *stream) override { return true; }
+ void DeallocateStream(Stream *stream) override {}
+ bool CreateStreamDependency(Stream *dependent, Stream *other) override;
+
+ bool AllocateTimer(Timer *timer) override { return true; }
+ void DeallocateTimer(Timer *timer) override {}
+ bool StartTimer(Stream *stream, Timer *timer) override;
+ bool StopTimer(Stream *stream, Timer *timer) override;
+
+ bool BlockHostUntilDone(Stream *stream) override;
+
+ int PlatformDeviceCount() override { return 1; }
+
+ bool DeviceMemoryUsage(int64 *free, int64 *total) const override {
+ return false;
+ }
+
+ DeviceDescription *PopulateDeviceDescription() const override;
+
+ port::Status EnablePeerAccessTo(StreamExecutorInterface *other) override {
+ return port::Status::OK();
+ }
+
+ bool CanEnablePeerAccessTo(StreamExecutorInterface *other) override {
+ return true;
+ }
+
+ SharedMemoryConfig GetDeviceSharedMemoryConfig() override {
+ return SharedMemoryConfig::kDefault;
+ }
+
+ port::Status SetDeviceSharedMemoryConfig(SharedMemoryConfig config) override {
+ return port::Status{port::error::UNIMPLEMENTED,
+ "Shared memory not supported"};
+ }
+
+ std::unique_ptr<internal::EventInterface> CreateEventImplementation()
+ override {
+ return nullptr;
+ }
+
+ std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
+ override {
+ return nullptr;
+ }
+
+ std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
+ override {
+ return std::unique_ptr<internal::StreamInterface>(new host::HostStream());
+ }
+
+ std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
+ return std::unique_ptr<internal::TimerInterface>(new host::HostTimer());
+ }
+
+ port::StatusOr<DeviceMemoryBase> ExecuteGraph(const xla::Shape &shape,
+ Args args);
+
+ private:
+ DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape);
+
+ port::StatusOr<DeviceMemoryBase> AllocateOutputBuffer(
+ const xla::Shape &shape);
+
+ const PluginConfig plugin_config_;
+};
+
+} // namespace executorplugin
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_EXECUTOR_H_
diff --git a/tensorflow/compiler/plugin/executor/platform.cc b/tensorflow/compiler/plugin/executor/platform.cc
new file mode 100644
index 0000000000..2f339f04a7
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/platform.cc
@@ -0,0 +1,125 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/plugin/executor/platform.h"
+#include "tensorflow/compiler/plugin/executor/executor.h"
+#include "tensorflow/compiler/plugin/executor/platform_id.h"
+
+#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/ptr_util.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/lib/status_macros.h"
+#include "tensorflow/stream_executor/lib/stringprintf.h"
+
+namespace se = ::perftools::gputools;
+namespace sep = ::perftools::gputools::executorplugin;
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+PLATFORM_DEFINE_ID(kExecutorPlatformId);
+
+ExecutorPlatform::ExecutorPlatform() : name_("Executor") {}
+
+ExecutorPlatform::~ExecutorPlatform() {}
+
+Platform::Id ExecutorPlatform::id() const { return kExecutorPlatformId; }
+
+int ExecutorPlatform::VisibleDeviceCount() const { return 1; }
+
+const string& ExecutorPlatform::Name() const { return name_; }
+
+port::StatusOr<StreamExecutor*> ExecutorPlatform::ExecutorForDevice(
+ int ordinal) {
+ StreamExecutorConfig config;
+ config.ordinal = ordinal;
+ config.plugin_config = PluginConfig();
+ config.device_options = DeviceOptions::Default();
+ return GetExecutor(config);
+}
+
+port::StatusOr<StreamExecutor*>
+ExecutorPlatform::ExecutorForDeviceWithPluginConfig(
+ int device_ordinal, const PluginConfig& plugin_config) {
+ StreamExecutorConfig config;
+ config.ordinal = device_ordinal;
+ config.plugin_config = plugin_config;
+ config.device_options = DeviceOptions::Default();
+ return GetExecutor(config);
+}
+
+port::StatusOr<StreamExecutor*> ExecutorPlatform::GetExecutor(
+ const StreamExecutorConfig& config) {
+ mutex_lock lock(executors_mutex_);
+
+ port::StatusOr<StreamExecutor*> status = executor_cache_.Get(config);
+ if (status.ok()) {
+ return status.ValueOrDie();
+ }
+
+ port::StatusOr<std::unique_ptr<StreamExecutor>> executor =
+ GetUncachedExecutor(config);
+ if (!executor.ok()) {
+ return executor.status();
+ }
+
+ StreamExecutor* naked_executor = executor.ValueOrDie().get();
+ SE_RETURN_IF_ERROR(
+ executor_cache_.Insert(config, executor.ConsumeValueOrDie()));
+ return naked_executor;
+}
+
+port::StatusOr<std::unique_ptr<StreamExecutor>>
+ExecutorPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
+ auto executor = port::MakeUnique<StreamExecutor>(
+ this, port::MakeUnique<ExecutorExecutor>(config.plugin_config));
+ auto init_status = executor->Init(config.ordinal, config.device_options);
+ if (!init_status.ok()) {
+ return port::Status{
+ port::error::INTERNAL,
+ port::Printf(
+ "failed initializing StreamExecutor for device ordinal %d: %s",
+ config.ordinal, init_status.ToString().c_str())};
+ }
+
+ return std::move(executor);
+}
+
+void ExecutorPlatform::RegisterTraceListener(
+ std::unique_ptr<TraceListener> listener) {
+ LOG(FATAL) << "not yet implemented: register executor trace listener";
+}
+
+void ExecutorPlatform::UnregisterTraceListener(TraceListener* listener) {
+ LOG(FATAL) << "not yet implemented: unregister executor trace listener";
+}
+
+static void InitializeExecutorPlatform() {
+ std::unique_ptr<se::Platform> platform(new sep::ExecutorPlatform);
+ SE_CHECK_OK(se::MultiPlatformManager::RegisterPlatform(std::move(platform)));
+}
+
+} // namespace executorplugin
+} // namespace gputools
+} // namespace perftools
+
+REGISTER_MODULE_INITIALIZER(executor_platform, sep::InitializeExecutorPlatform());
+
+DECLARE_MODULE_INITIALIZER(multi_platform_manager);
+// Note that module initialization sequencing is not supported in the
+// open-source project, so this will be a no-op there.
+REGISTER_MODULE_INITIALIZER_SEQUENCE(executor_platform, multi_platform_manager);
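As a companion sketch (again not part of the change), the platform registered by InitializeExecutorPlatform can be looked up by the "Executor" name set in its constructor and turned into a StreamExecutor; error handling is elided, and the calls shown are the standard MultiPlatformManager and Platform interfaces already used by this file.

    // Sketch: locate the registered platform and obtain a stream executor.
    #include "tensorflow/stream_executor/multi_platform_manager.h"
    #include "tensorflow/stream_executor/stream.h"
    #include "tensorflow/stream_executor/stream_executor.h"

    namespace se = ::perftools::gputools;

    void UseExecutorPlatform() {
      se::Platform* platform =
          se::MultiPlatformManager::PlatformWithName("Executor").ValueOrDie();
      se::StreamExecutor* executor =
          platform->ExecutorForDevice(0).ValueOrDie();
      // Streams created on this executor are backed by host::HostStream, as
      // implemented by ExecutorExecutor in executor.cc.
      se::Stream stream(executor);
      stream.Init();
    }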
diff --git a/tensorflow/compiler/plugin/executor/platform.h b/tensorflow/compiler/plugin/executor/platform.h
new file mode 100644
index 0000000000..c252a589d4
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/platform.h
@@ -0,0 +1,83 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_PLATFORM_H_
+#define TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_PLATFORM_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/stream_executor/executor_cache.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/platform.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/stream_executor_pimpl.h"
+#include "tensorflow/stream_executor/trace_listener.h"
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+class ExecutorPlatform : public Platform {
+ public:
+ ExecutorPlatform();
+ ~ExecutorPlatform() override;
+
+ Platform::Id id() const override;
+
+ // Device count is less clear-cut for CPUs than accelerators. This
+ // platform currently exposes a single visible executor device (see
+ // VisibleDeviceCount() in platform.cc).
+ int VisibleDeviceCount() const override;
+
+ const string& Name() const override;
+
+ port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
+
+ port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
+ int ordinal, const PluginConfig& config) override;
+
+ port::StatusOr<StreamExecutor*> GetExecutor(
+ const StreamExecutorConfig& config) override;
+
+ port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
+ const StreamExecutorConfig& config) override;
+
+ void RegisterTraceListener(std::unique_ptr<TraceListener> listener) override;
+
+ void UnregisterTraceListener(TraceListener* listener) override;
+
+ private:
+ // This platform's name.
+ string name_;
+
+ // mutex that guards the ordinal-to-executor map.
+ mutable mutex executors_mutex_;
+
+ // Cache of created StreamExecutors.
+ ExecutorCache executor_cache_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(ExecutorPlatform);
+};
+
+} // namespace executorplugin
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_COMPILER_EXECUTOR_STREAM_EXECUTOR_EXECUTOR_PLATFORM_H_
diff --git a/tensorflow/compiler/plugin/executor/platform_id.h b/tensorflow/compiler/plugin/executor/platform_id.h
new file mode 100644
index 0000000000..8d2b29a3e4
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/platform_id.h
@@ -0,0 +1,31 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_PLATFORM_ID_H_
+#define TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_PLATFORM_ID_H_
+
+#include "tensorflow/stream_executor/platform.h"
+
+namespace perftools {
+namespace gputools {
+namespace executorplugin {
+
+extern const Platform::Id kExecutorPlatformId;
+
+} // namespace executorplugin
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_PLATFORM_ID_H_
diff --git a/tensorflow/compiler/plugin/executor/transfer_manager.cc b/tensorflow/compiler/plugin/executor/transfer_manager.cc
new file mode 100644
index 0000000000..51c5deeea5
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/transfer_manager.cc
@@ -0,0 +1,187 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/plugin/executor/transfer_manager.h"
+#include "tensorflow/compiler/plugin/executor/platform_id.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace sep = ::perftools::gputools::executorplugin;
+
+namespace xla {
+namespace executorplugin {
+
+ExecutorTransferManager::ExecutorTransferManager() {}
+
+se::Platform::Id ExecutorTransferManager::PlatformId() const {
+ return se::executorplugin::kExecutorPlatformId;
+}
+
+Status ExecutorTransferManager::TransferLiteralFromDevice(
+ se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+ const Shape& device_shape, const Shape& literal_shape, Literal* literal) {
+ TF_RET_CHECK(ShapeUtil::Compatible(device_shape, literal_shape));
+
+ // Tuples are a special case and contain one or more shapes inside of them to
+ // an arbitrary nesting depth.
+ if (device_shape.element_type() == TUPLE) {
+ *literal->mutable_shape() = literal_shape;
+ TF_ASSIGN_OR_RETURN(
+ std::vector<se::DeviceMemoryBase> element_buffers,
+ ShallowCopyTupleFromDevice(executor, source, device_shape));
+ TF_RET_CHECK(element_buffers.size() ==
+ ShapeUtil::TupleElementCount(device_shape));
+ for (int64 i = 0; i < element_buffers.size(); ++i) {
+ const Shape& element_device_shape = device_shape.tuple_shapes(i);
+ const Shape& element_literal_shape = literal_shape.tuple_shapes(i);
+ Literal* element_literal = literal->add_tuple_literals();
+ // Recursively call TransferFromDevice to copy over the data in the
+ // element array.
+ TF_RETURN_IF_ERROR(TransferLiteralFromDevice(
+ executor, element_buffers[i], element_device_shape,
+ element_literal_shape, element_literal));
+ }
+ return Status::OK();
+ }
+
+ *literal->mutable_shape() = device_shape;
+ literal->Reserve(ShapeUtil::ElementsIn(device_shape));
+ TF_RETURN_IF_ERROR(TransferBufferFromDevice(
+ executor, source, ShapeUtil::ByteSizeOf(device_shape),
+ literal->MutableInternalData()));
+ if (!ShapeUtil::Equal(literal_shape, device_shape)) {
+ literal->Swap(
+ literal->Relayout(literal_shape.layout()).get());
+ }
+ TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape()));
+ return Status::OK();
+}
+
+StatusOr<std::vector<se::DeviceMemoryBase>>
+ExecutorTransferManager::ShallowCopyTupleFromDevice(
+ se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+ const Shape& shape) {
+ TF_RET_CHECK(ShapeUtil::IsTuple(shape));
+
+ std::vector<void*> element_pointers(ShapeUtil::TupleElementCount(shape),
+ nullptr);
+ int64 tuple_size = ShapeUtil::ByteSizeOf(shape, sizeof(void*));
+ auto copy_status = executor->SynchronousMemcpyD2H(source, tuple_size,
+ element_pointers.data());
+ if (!copy_status.ok()) {
+ return AddStatus(
+ Status(static_cast<tensorflow::error::Code>(copy_status.code()),
+ copy_status.error_message()),
+ "failed transfer of tuple buffer " + ShapeUtil::HumanString(shape));
+ }
+
+ // Create a DeviceMemoryBase from each void* pointer.
+ std::vector<se::DeviceMemoryBase> destination;
+ for (int i = 0; i < element_pointers.size(); ++i) {
+ if (element_pointers[i] == nullptr &&
+ !ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) {
+ return FailedPrecondition("tuple contains nullptr at element %d", i);
+ }
+ int64 buffer_size =
+ ShapeUtil::ByteSizeOf(shape.tuple_shapes(i), sizeof(void*));
+ destination.emplace_back(element_pointers[i], buffer_size);
+ }
+ return std::move(destination);
+}
+
+Status ExecutorTransferManager::TransferLiteralToDevice(
+ se::StreamExecutor* executor, const Literal& literal,
+ se::DeviceMemoryBase* destination) {
+ const Shape& shape = literal.shape();
+
+ if (ShapeUtil::IsTuple(literal.shape())) {
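+    // For a tuple, transfer each element into its own device allocation, then
+    // write the table of element pointers to the destination buffer.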
+ std::vector<void*> tuple_elements_on_device;
+ for (const Literal& tuple_element : literal.tuple_literals()) {
+ se::DeviceMemoryBase allocation = executor->AllocateArray<uint8>(
+ GetByteSizeRequirement(tuple_element.shape()));
+ TF_RETURN_IF_ERROR(
+ TransferLiteralToDevice(executor, tuple_element, &allocation));
+ tuple_elements_on_device.push_back(allocation.opaque());
+ }
+ return TransferBufferToDevice(
+ executor, tuple_elements_on_device.size() * sizeof(void*),
+ tuple_elements_on_device.data(), destination);
+ }
+
+ return TransferBufferToDevice(executor, GetByteSizeRequirement(shape),
+ literal.InternalData(),
+ destination);
+}
+
+Status ExecutorTransferManager::TransferLiteralToInfeed(
+ se::StreamExecutor* executor, const Literal& literal) {
+ const Shape& shape = literal.shape();
+ VLOG(1) << "transferring literal shape to infeed: "
+ << ShapeUtil::HumanString(shape);
+
+ return Status::OK();
+}
+
+Status ExecutorTransferManager::TransferBufferToInfeed(
+ se::StreamExecutor* executor, int64 size, const void* source) {
+ return Unimplemented("Transfer to Infeed");
+}
+
+Status ExecutorTransferManager::TransferLiteralFromOutfeed(
+ perftools::gputools::StreamExecutor* executor, const Shape& literal_shape,
+ Literal* literal) {
+ const Shape& shape = literal->shape();
+ VLOG(1) << "transferring literal shape from outfeed: "
+ << ShapeUtil::HumanString(shape);
+
+ return Status::OK();
+}
+
+Status ExecutorTransferManager::ResetDevices(
+ tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
+ executors) {
+ return Unimplemented("Device reset not supported");
+}
+
+int64 ExecutorTransferManager::GetByteSizeRequirement(const Shape& shape) {
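+  // Tuples are represented on the device as arrays of pointers to their
+  // elements, so the pointer size is passed to ByteSizeOf.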
+ return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
+}
+
+} // namespace executorplugin
+} // namespace xla
+
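+// Registers the executor plugin's TransferManager factory with XLA at load
+// time so the service can look it up by platform id.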
+static std::unique_ptr<xla::TransferManager> CreateExecutorTransferManager() {
+ return xla::MakeUnique<xla::executorplugin::ExecutorTransferManager>();
+}
+
+static bool InitModule() {
+ xla::TransferManager::RegisterTransferManager(sep::kExecutorPlatformId,
+ &CreateExecutorTransferManager);
+ return true;
+}
+static bool module_initialized = InitModule();
diff --git a/tensorflow/compiler/plugin/executor/transfer_manager.h b/tensorflow/compiler/plugin/executor/transfer_manager.h
new file mode 100644
index 0000000000..7a42e5a2d7
--- /dev/null
+++ b/tensorflow/compiler/plugin/executor/transfer_manager.h
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_TRANSFER_MANAGER_H_
+#define TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_TRANSFER_MANAGER_H_
+
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+
+#include <vector>
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace executorplugin {
+
+class ExecutorTransferManager : public TransferManager {
+ public:
+ ExecutorTransferManager();
+
+ ~ExecutorTransferManager() override {}
+
+ se::Platform::Id PlatformId() const override;
+
+ StatusOr<std::vector<se::DeviceMemoryBase>> ShallowCopyTupleFromDevice(
+ se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+ const Shape& shape) override;
+
+ Status TransferLiteralFromDevice(se::StreamExecutor* executor,
+ const se::DeviceMemoryBase& source,
+ const Shape& device_shape,
+ const Shape& literal_shape,
+ Literal* literal) override;
+
+ Status TransferLiteralToDevice(se::StreamExecutor* executor,
+ const Literal& literal,
+ se::DeviceMemoryBase* destination) override;
+
+ Status TransferLiteralToInfeed(se::StreamExecutor* executor,
+ const Literal& literal) override;
+
+ Status TransferBufferToInfeed(se::StreamExecutor* executor,
+ int64 size, const void* source) override;
+
+ Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
+ const Shape& literal_shape,
+ Literal* literal) override;
+
+ Status ResetDevices(
+ tensorflow::gtl::ArraySlice<se::StreamExecutor*> executors) override;
+
+ int64 GetByteSizeRequirement(const Shape& shape) override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(ExecutorTransferManager);
+};
+
+} // namespace executorplugin
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_EXECUTOR_DRIVER_EXECUTOR_TRANSFER_MANAGER_H_
diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py
index 6b328fb618..a75a5cd2cf 100644
--- a/tensorflow/compiler/tests/ftrl_test.py
+++ b/tensorflow/compiler/tests/ftrl_test.py
@@ -218,7 +218,7 @@ class FtrlOptimizerTest(XLATestCase):
self.assertAllClose(np.array([-0.24059935, -0.46829352]), var0.eval())
self.assertAllClose(np.array([-0.02406147, -0.04830509]), var1.eval())
- # When variables are intialized with Zero, FTRL-Proximal has two properties:
+ # When variables are initialized with Zero, FTRL-Proximal has two properties:
# 1. Without L1&L2 but with fixed learning rate, FTRL-Proximal is identical
# with GradientDescent.
# 2. Without L1&L2 but with adaptive learning rate, FTRL-Proximal is identical
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
index f752fb3ae2..16b778bca4 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
@@ -94,12 +94,14 @@ class BatchMatMulOp : public XlaOpKernel {
// Slice off individual matrices and reshape to 2D tensors.
auto x_slice = builder->Slice(
x_flat, {i, 0, 0},
- {i + 1, x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)});
+ {i + 1, x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)},
+ {1, 1, 1});
x_slice = builder->Reshape(
x_slice, {x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)});
auto y_slice = builder->Slice(
y_flat, {i, 0, 0},
- {i + 1, y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)});
+ {i + 1, y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)},
+ {1, 1, 1});
y_slice = builder->Reshape(
y_slice, {y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)});
diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
index 47d2d747e6..21d3e64872 100644
--- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
@@ -125,6 +125,7 @@ void BatchToSpace(XlaOpKernelContext* ctx,
// input_shape[M+1], ..., input_shape[N-1]]
std::vector<int64> start_indices(input_rank, 0);
std::vector<int64> end_indices = reshaped_permuted_shape;
+ std::vector<int64> strides(input_rank, 1);
for (int i = 0; i < block_rank; ++i) {
int64 crop_start = crops.Get<int64>({i, 0});
int64 crop_end = crops.Get<int64>({i, 1});
@@ -139,7 +140,7 @@ void BatchToSpace(XlaOpKernelContext* ctx,
" end: ", crop_end, " size ", reshaped_permuted_shape[1 + i]));
}
xla::ComputationDataHandle output =
- b->Slice(reshaped_permuted, start_indices, end_indices);
+ b->Slice(reshaped_permuted, start_indices, end_indices, strides);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc
index 92b371cc4e..852d2a966e 100644
--- a/tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc
@@ -172,15 +172,14 @@ class DepthwiseConv2dNativeOp : public XlaOpKernel {
} else {
// These will be used to define the bounds of each slice.
// Within the loop, the input_channel index will be modified.
- gtl::InlinedVector<int64, 4> filter_begin;
- gtl::InlinedVector<int64, 4> filter_limits;
- gtl::InlinedVector<int64, 4> input_begin;
- gtl::InlinedVector<int64, 4> input_limits;
+ gtl::InlinedVector<int64, 4> filter_begin(4, 0);
+ gtl::InlinedVector<int64, 4> filter_limits(4);
+ gtl::InlinedVector<int64, 4> input_begin(4, 0);
+ gtl::InlinedVector<int64, 4> input_limits(4);
+ gtl::InlinedVector<int64, 4> strides(4, 1);
for (int i = 0; i < 4; ++i) {
- filter_begin.push_back(0);
- filter_limits.push_back(filter_shape.dim_size(i));
- input_begin.push_back(0);
- input_limits.push_back(input_shape.dim_size(i));
+ filter_limits[i] = filter_shape.dim_size(i);
+ input_limits[i] = input_shape.dim_size(i);
}
std::vector<int64> strides_for_tla{strides_[1], strides_[2]};
@@ -209,9 +208,9 @@ class DepthwiseConv2dNativeOp : public XlaOpKernel {
input_limits[3] = i + 1;
xla::ComputationDataHandle filter_slice =
- b.Slice(filter, filter_begin, filter_limits);
+ b.Slice(filter, filter_begin, filter_limits, strides);
xla::ComputationDataHandle input_slice =
- b.Slice(input, input_begin, input_limits);
+ b.Slice(input, input_begin, input_limits, strides);
convs.push_back(b.ConvWithGeneralDimensions(
input_slice, filter_slice, strides_for_tla, xla_padding, dims));
}
diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
index 74994d8961..ec5017f6ab 100644
--- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
@@ -125,7 +125,7 @@ class DiagPartOp : public XlaOpKernel {
diag = builder->Reshape(diag, {new_size, new_size + 1});
// Slices out the first column and reshapes to the final shape.
- diag = builder->Slice(diag, {0, 0}, {new_size, 1});
+ diag = builder->Slice(diag, {0, 0}, {new_size, 1}, {1, 1});
diag = builder->Reshape(diag, new_dims);
ctx->SetOutput(0, diag);
@@ -224,8 +224,9 @@ class MatrixDiagPartOp : public XlaOpKernel {
} else if (actual_size > target_size) {
std::vector<int64> start(flattened_dims.size(), 0);
std::vector<int64> limits(flattened_dims.begin(), flattened_dims.end());
+ std::vector<int64> strides(flattened_dims.size(), 1);
limits[flattened_dims.size() - 1] = target_size;
- diag = builder->Slice(diag, start, limits);
+ diag = builder->Slice(diag, start, limits, strides);
}
// Reshape so the target values are in the first position of the last
@@ -238,8 +239,9 @@ class MatrixDiagPartOp : public XlaOpKernel {
// Slices out the first column and reshapes to the final shape.
std::vector<int64> start(dims.size(), 0);
std::vector<int64> limits(dims.begin(), dims.end());
+ std::vector<int64> strides(dims.size(), 1);
limits[last_dim] = 1;
- diag = builder->Slice(diag, start, limits);
+ diag = builder->Slice(diag, start, limits, strides);
// Collapses away the last dimension.
dims.pop_back();
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
index faa7ef0ef9..0330e34c98 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
@@ -156,6 +156,8 @@ class DynamicStitchOp : public XlaOpKernel {
indices0_shape.dims());
std::vector<int64> slice_limit(1 + data0_shape.dims() -
indices0_shape.dims());
+ std::vector<int64> stride(1 + data0_shape.dims() -
+ indices0_shape.dims(), 1);
for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) {
slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d);
}
@@ -168,7 +170,7 @@ class DynamicStitchOp : public XlaOpKernel {
// And place it in the concat list in the place indicated by
// the index.
to_concat[index_num] =
- ctx->builder()->Slice(expression, slice_start, slice_limit);
+ ctx->builder()->Slice(expression, slice_start, slice_limit, stride);
}
ctx->SetOutput(0, ctx->builder()->ConcatInDim(to_concat, 0));
diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
index 51c97d85d7..482c54a40c 100644
--- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
@@ -54,7 +54,9 @@ class SliceOp : public XlaOpKernel {
for (int i = 0; i < begin.size(); ++i) {
limits.push_back(begin[i] + size[i]);
}
- ctx->SetOutput(0, ctx->builder()->Slice(ctx->Input(0), begin, limits));
+ std::vector<int64> strides(begin.size(), 1);
+ ctx->SetOutput(0, ctx->builder()->Slice(ctx->Input(0), begin, limits,
+ strides));
}
private:
diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc
index 017f3a110e..44ee81461e 100644
--- a/tensorflow/compiler/tf2xla/kernels/split_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc
@@ -77,14 +77,14 @@ class SplitOp : public XlaOpKernel {
// The vectors we will use to define the slice. The entry for the
// split dimensions varies for each output.
- std::vector<int64> begin;
- std::vector<int64> limits;
+ std::vector<int64> begin(input_shape.dims(), 0);
+ std::vector<int64> limits(input_shape.dims());
+ std::vector<int64> strides(input_shape.dims(), 1);
for (int i = 0; i < input_shape.dims(); ++i) {
// Initially set up the limits to be the full size of the input:
// the split dimension is filled in below.
int64 dim = input_shape.dim_size(i);
- begin.push_back(0);
- limits.push_back(dim);
+ limits[i] = dim;
}
auto input = ctx->Input(1);
@@ -94,7 +94,7 @@ class SplitOp : public XlaOpKernel {
// Slice out the ith split from the split dimension.
begin[split_dim] = i * slice_size;
limits[split_dim] = (i + 1) * slice_size;
- ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits));
+ ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides));
}
}
};
@@ -188,7 +188,7 @@ class SplitVOp : public XlaOpKernel {
std::vector<int64> begin(input_shape.dims(), 0);
auto dim_sizes = input_shape.dim_sizes();
std::vector<int64> limits(dim_sizes.begin(), dim_sizes.end());
-
+ std::vector<int64> strides(input_shape.dims(), 1);
for (int i = 0; i < num_split; ++i) {
TensorShape output_shape(input_shape);
int slice_size = split_sizes_vec[i];
@@ -196,7 +196,7 @@ class SplitVOp : public XlaOpKernel {
// Slice out the ith split from the split dimension.
limits[split_dim] = begin[split_dim] + slice_size;
- ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits));
+ ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides));
begin[split_dim] = limits[split_dim];
}
}
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 8037e90791..6af4bd0496 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -72,55 +72,29 @@ class StridedSliceOp : public XlaOpKernel {
&dummy, &dummy, &dummy, &begin, &end, &strides));
gtl::InlinedVector<int64, 4> dimensions_to_reverse;
- gtl::InlinedVector<int64, 4> slice_begin, slice_end;
- bool simple_strides = true;
+ gtl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
+
for (int i = 0; i < begin.size(); ++i) {
- simple_strides &= (std::abs(strides[i]) == 1);
if (strides[i] > 0) {
slice_begin.push_back(begin[i]);
slice_end.push_back(end[i]);
+ slice_strides.push_back(strides[i]);
} else {
// Negative stride: the dimension is reversed below, so mirror the begin/end
// indices against the dimension size and use the positive stride.
- slice_begin.push_back(end[i] + 1);
- slice_end.push_back(begin[i] + 1);
+ slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1);
+ slice_end.push_back(input_shape.dim_size(i) - end[i] - 1);
+ slice_strides.push_back(-strides[i]);
dimensions_to_reverse.push_back(i);
}
}
- xla::ComputationDataHandle slice =
- ctx->builder()->Slice(ctx->Input(0), slice_begin, slice_end);
+
+ xla::ComputationDataHandle slice = ctx->Input(0);
if (!dimensions_to_reverse.empty()) {
slice = ctx->builder()->Rev(slice, dimensions_to_reverse);
}
- // If at least one of the strides is > 1 (or < -1) then use Slice
- // to pull out each of the strided slices, and Concat to put them
- // together again.
- if (!simple_strides) {
- // Re-adjust the begin and end now that the periphery has been
- // sliced away.
- for (int d = 0; d < strides.size(); ++d) {
- slice_end[d] -= slice_begin[d];
- slice_begin[d] = 0;
- }
-
- for (int d = 0; d < strides.size(); ++d) {
- int64 stride = std::abs(strides[d]);
- if (stride > 1) {
- std::vector<xla::ComputationDataHandle> to_concat;
- int64 end = slice_end[d];
- for (int64 i = 0; i < end; i += stride) {
- slice_begin[d] = i;
- slice_end[d] = i + 1;
- to_concat.push_back(
- ctx->builder()->Slice(slice, slice_begin, slice_end));
- }
- slice = ctx->builder()->ConcatInDim(to_concat, d);
- slice_begin[d] = 0;
- slice_end[d] = to_concat.size();
- }
- }
- }
+ slice = ctx->builder()->Slice(slice, slice_begin, slice_end, slice_strides);
slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes());
ctx->SetOutput(0, slice);
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 598b341002..9367c1ef22 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -318,7 +318,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
for (int i = 0; i < num_indices; ++i) {
// Slices the i-th index out of `indices`, and pads it with zeros in the
// minor dimensions to form an index into the TensorArray storage.
- auto index = b->Slice(indices, {i}, {i + 1});
+ auto index = b->Slice(indices, {i}, {i + 1}, {1});
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
auto start_indices = PadIndexWithZeros(b, index, ta_shape.dims() - 1);
@@ -381,16 +381,18 @@ class TensorArrayScatterOp : public XlaOpKernel {
std::vector<int64> value_starts(value_shape.dims(), 0);
auto value_ends = value_shape.dim_sizes();
+ std::vector<int64> value_strides(value_shape.dims(), 1);
+
// For every (index, value) pair, update the corresponding TensorArray
// storage.
for (int i = 0; i < num_indices; ++i) {
// Slice out part of the value.
value_starts[0] = i;
value_ends[0] = i + 1;
- auto slice = b->Slice(value, value_starts, value_ends);
+ auto slice = b->Slice(value, value_starts, value_ends, value_strides);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
- auto index = b->Slice(indices, {i}, {i + 1});
+ auto index = b->Slice(indices, {i}, {i + 1}, {1});
auto start_indices = PadIndexWithZeros(b, index, elem_shape.dims());
ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
index a5ce78e520..f87586ba57 100644
--- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
@@ -66,6 +66,7 @@ class UnpackOp : public XlaOpKernel {
std::vector<int64> start_indices(input_shape.dims(), 0);
std::vector<int64> limit_indices(input_shape.dims());
+ std::vector<int64> strides(input_shape.dims(), 1);
for (int i = 0; i < input_shape.dims(); ++i) {
limit_indices[i] = input_shape.dim_size(i);
}
@@ -73,7 +74,8 @@ class UnpackOp : public XlaOpKernel {
for (int i = 0; i < num; ++i) {
start_indices[axis] = i;
limit_indices[axis] = i + 1;
- auto slice = ctx->builder()->Slice(input, start_indices, limit_indices);
+ auto slice = ctx->builder()->Slice(input, start_indices, limit_indices,
+ strides);
// Reshape to drop the 'axis' dimension.
auto result = ctx->builder()->Reshape(slice, output_shape.dim_sizes());
ctx->SetOutput(i, result);
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index 735a69d596..dcc313707b 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -256,7 +256,8 @@ void ComputationBuilder::CheckSameShape(const ComputationDataHandle& lhs,
ComputationDataHandle ComputationBuilder::Slice(
const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices) {
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> stride) {
if (!first_error_.ok() || !PrepareComputation().ok()) {
return ComputationDataHandle();
}
@@ -269,6 +270,9 @@ ComputationDataHandle ComputationBuilder::Slice(
for (int64 index : limit_indices) {
request.add_limit_indices(index);
}
+ for (int64 index : stride) {
+ request.add_stride(index);
+ }
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_slice_request() = request;
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index 5dceb03281..b411346459 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -211,9 +211,11 @@ class ComputationBuilder {
//
// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
// range notation.
+  // The stride parameter determines the step between selected elements in
+  // each dimension.
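+  // For example, Slice(operand, {0}, {8}, {2}) on a rank-1 operand selects
+  // elements 0, 2, 4 and 6.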
ComputationDataHandle Slice(const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices);
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> stride);
// Enqueues a slice operation onto the computation that slices the 'operand'
// from dynamic start indices which are passed in 'start_indices'.
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 1b125e3596..b6bd1158d2 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -1205,11 +1205,7 @@ void Literal::Resize<double>(int64 num_elements, double value) {
template <>
void Literal::Resize<half>(int64 num_elements, half value) {
CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
- mutable_f16s()->resize(num_elements * sizeof(half));
- auto data = GetMutableArraySlice<half>();
- for (int i = 0; i < num_elements; i++) {
- data[i] = value;
- }
+ mutable_f16s()->resize(num_elements, value);
}
template <typename RepeatedFieldT, typename NativeT>
@@ -1252,7 +1248,7 @@ LiteralProto Literal::ToProto() const {
case F16:
*proto.mutable_f16s() =
string(reinterpret_cast<const char*>(f16s_.data()),
- f16s_.size() / sizeof(half));
+ f16s_.size() * sizeof(half));
break;
case F32:
CopyToRepeatedField(proto.mutable_f32s(), f32s());
@@ -1308,7 +1304,7 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
const string& s(literal_proto.f16s());
CHECK_EQ(0, s.size() % sizeof(half));
f16s_ = std::vector<half>(s.size() / sizeof(half));
- memcpy(f16s_.data(), s.data(), s.size() / sizeof(half));
+ memcpy(f16s_.data(), s.data(), s.size());
break;
}
case F32:
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index ffae623b0c..5a550ef4c6 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -939,5 +939,62 @@ TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
}
}
+// Note that f16 is currently stored in a byte array in little-endian byte order.
+TEST_F(LiteralUtilTest, ToProto_f16) {
+ half h1(1.0f);
+ half h2(2.0f);
+
+ auto m = Literal::CreateR2<half>({{h1, h2}, {h2, h1}});
+ Literal* l = m.get();
+ EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape()));
+ EXPECT_EQ(4, l->f16s().size());
+ EXPECT_EQ(4, l->f16s_size());
+
+ LiteralProto p = l->ToProto();
+ EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape()));
+ EXPECT_EQ(8, p.f16s().size());
+ const char* d = p.f16s().data();
+ EXPECT_EQ(d[0], 0);
+ EXPECT_EQ(d[1], 0x3C);
+ EXPECT_EQ(d[2], 0);
+ EXPECT_EQ(d[3], 0x40);
+ EXPECT_EQ(d[4], 0);
+ EXPECT_EQ(d[5], 0x40);
+ EXPECT_EQ(d[6], 0);
+ EXPECT_EQ(d[7], 0x3C);
+}
+
+// Note that f16 is currently stored in a byte array in little-endian byte order.
+TEST_F(LiteralUtilTest, CopyFromProto_f16) {
+ half h1(1.0f);
+ half h2(2.0f);
+
+ const char half_vals[8] = {
+ 0x00, 0x3C, 0x00, 0x40, 0x00, 0x40, 0x00, 0x3C
+ };
+ LiteralProto p;
+ p.mutable_shape()->set_element_type(F16);
+ p.mutable_shape()->clear_dimensions();
+ p.mutable_shape()->add_dimensions(4);
+ p.clear_f16s();
+ p.set_f16s(half_vals, 8);
+
+ Literal literal(p);
+ ASSERT_EQ(4, literal.f16s_size());
+ ASSERT_EQ(h1, literal.f16s(0));
+ ASSERT_EQ(h2, literal.f16s(1));
+ ASSERT_EQ(h2, literal.f16s(2));
+ ASSERT_EQ(h1, literal.f16s(3));
+
+ const std::vector<half>& r = literal.f16s();
+ ASSERT_EQ(4, r.size());
+ ASSERT_EQ(h1, r[0]);
+ ASSERT_EQ(h2, r[1]);
+ ASSERT_EQ(h2, r[2]);
+ ASSERT_EQ(h1, r[3]);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 718a2d798c..99b1337b11 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -90,8 +90,6 @@ cc_library(
":hlo_query",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:status",
- "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 0187c09d7b..5709ac3067 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -855,6 +855,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
// Second, construct the slice instruction to perform the negative padding.
std::vector<int64> start_indices;
std::vector<int64> end_indices;
+ std::vector<int64> strides;
for (int64 i = 0; i < pad->padding_config().dimensions_size(); ++i) {
const PaddingConfig::PaddingConfigDimension& padding_dimension =
pad->padding_config().dimensions(i);
@@ -868,16 +869,18 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
}
start_indices.push_back(start);
end_indices.push_back(end);
+ strides.push_back(1);
}
// Verify that the slice shape matches the pad shape.
TF_ASSIGN_OR_RETURN(Shape inferred_slice_shape,
ShapeInference::InferSliceShape(
- nonzero_pad_shape, start_indices, end_indices));
+ nonzero_pad_shape, start_indices, end_indices,
+ strides));
TF_RET_CHECK(ShapeUtil::Compatible(inferred_slice_shape, pad->shape()));
std::unique_ptr<HloInstruction> slice = HloInstruction::CreateSlice(
- pad->shape(), nonzero_pad, start_indices, end_indices);
+ pad->shape(), nonzero_pad, start_indices, end_indices, strides);
return ReplaceWithNewInstruction(pad, std::move(slice));
}
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 0792006ddb..7e52c8fb0c 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -520,7 +520,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
HloInstruction* empty_slice =
builder.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}));
+ ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1}));
Shape result_shape = ShapeUtil::MakeShape(F32, {3 * kParamLength});
builder.AddInstruction(HloInstruction::CreateConcatenate(
result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0));
@@ -551,7 +551,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
HloInstruction* empty_slice =
builder.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}));
+ ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1}));
Shape result_shape = ShapeUtil::MakeShape(F32, {0});
builder.AddInstruction(HloInstruction::CreateConcatenate(
result_shape, {empty_literal, empty_slice}, 0));
@@ -1132,7 +1132,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param"));
builder.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0},
- /*limit_indices=*/{dim0, dim1}));
+      /*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1}));
HloModule module(TestName());
HloComputation* computation = module.AddEntryComputation(builder.Build());
@@ -1537,7 +1537,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
- slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}));
+ slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1}));
HloModule module(TestName());
auto computation = module.AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index c498b86dd4..56568fd446 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -731,7 +731,7 @@ TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) {
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
auto slice = builder.AddInstruction(
- HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
+ HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
@@ -763,7 +763,7 @@ TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) {
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
auto slice = builder.AddInstruction(
- HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
+ HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast}));
@@ -800,7 +800,7 @@ TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) {
auto tuple_element = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(f32vec100_, tuple, 0));
auto slice = builder.AddInstruction(
- HloInstruction::CreateSlice(f32vec10_, tuple_element, {0}, {10}));
+ HloInstruction::CreateSlice(f32vec10_, tuple_element, {0}, {10}, {1}));
auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
builder.AddInstruction(HloInstruction::CreateTuple({tuple, broadcast}));
@@ -835,7 +835,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) {
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
// Slice output is 10 elements.
auto slice = builder.AddInstruction(
- HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
+ HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
// Broadcast output is 40 elements.
auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
@@ -867,7 +867,7 @@ TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) {
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
auto slice = builder.AddInstruction(
- HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
+ HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
// Broadcast output is 40 elements.
auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {10, 10}), slice, {0}));
@@ -904,7 +904,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) {
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
// Slice output is 10 elements.
auto slice = builder.AddInstruction(
- HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
+ HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
// Broadcast output is 40 elements.
auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index a31e9b1782..a5f7cc0aeb 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -588,7 +588,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
if (update_uses_tuple_element1) {
// Create a slice instruction as an additional user of 'gte1'.
slice = builder.AddInstruction(
- HloInstruction::CreateSlice(update_shape, gte1, {0}, {3}));
+ HloInstruction::CreateSlice(update_shape, gte1, {0}, {3}, {1}));
update = builder.AddInstruction(HloInstruction::CreateBinary(
update_shape, HloOpcode::kAdd, update, slice));
}
diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h
index dd00c58240..0a1911cbd1 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.h
+++ b/tensorflow/compiler/xla/service/compile_only_service.h
@@ -55,7 +55,7 @@ class CompileOnlyService : public Service {
// Override Service methods that require or imply the existence of an
// execute backend. Note that this does not include TransferToClient, as
- // computing contants produces global data that we may wish to transfer.
+ // computing constants produces global data that we may wish to transfer.
tensorflow::Status Execute(const ExecuteRequest* arg,
ExecuteResponse* result) override {
return Unimplemented("CompileOnlyService does not support execution.");
diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc
index cdf277581f..cdfa30dd9a 100644
--- a/tensorflow/compiler/xla/service/computation_placer.cc
+++ b/tensorflow/compiler/xla/service/computation_placer.cc
@@ -49,17 +49,18 @@ Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {
return Status::OK();
}
-/* static */ StatusOr<DeviceAssignment> DeviceAssignment::Deserialize(
- const DeviceAssignmentProto& proto) {
+/* static */ StatusOr<std::unique_ptr<DeviceAssignment>>
+DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) {
TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count());
- DeviceAssignment assignment(proto.replica_count(), proto.computation_count());
+ auto assignment = MakeUnique<DeviceAssignment>(proto.replica_count(),
+ proto.computation_count());
for (int computation = 0; computation < proto.computation_count();
++computation) {
const auto& computation_device = proto.computation_devices(computation);
TF_RET_CHECK(computation_device.replica_device_ids_size() ==
proto.replica_count());
for (int replica = 0; replica < proto.replica_count(); ++replica) {
- assignment(replica, computation) =
+ (*assignment)(replica, computation) =
computation_device.replica_device_ids(replica);
}
}
diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h
index 4d26d6bb85..7d9abcd100 100644
--- a/tensorflow/compiler/xla/service/computation_placer.h
+++ b/tensorflow/compiler/xla/service/computation_placer.h
@@ -49,7 +49,11 @@ class DeviceAssignment : public Array2D<int> {
// Protocol buffer serialization and deserialization.
Status Serialize(DeviceAssignmentProto* proto) const;
- static StatusOr<DeviceAssignment> Deserialize(
+
+ // Return a std::unique_ptr<DeviceAssignment> instead of a DeviceAssignment
+ // directly because one of the supported TF platforms (mac) does not compile
+ // due to a StatusOr of an incomplete type (DeviceAssignment).
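+  // Callers typically unwrap the result, e.g.:
+  //   TF_ASSIGN_OR_RETURN(auto assignment, DeviceAssignment::Deserialize(proto));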
+ static StatusOr<std::unique_ptr<DeviceAssignment>> Deserialize(
const DeviceAssignmentProto& proto);
};
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index da8d983e1a..759d27e1f3 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -359,7 +359,6 @@ Status AppendIRToFile(const string& file_name, const string& ir_module_string) {
StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
std::unique_ptr<HloModule> module, HloDumper dump_hlo,
se::StreamExecutor* stream_exec) {
- VLOG(1) << "Compiling: " << module->name();
TF_RET_CHECK(stream_exec != nullptr);
std::call_once(llvm_command_line_options_initialized,
&InitializeLLVMCommandLineOptions, module->config());
@@ -404,8 +403,6 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
module->config().debug_options().xla_dump_debug_json_to();
if (CpuParallelBackendRequested(module->config())) {
- VLOG(1) << "Using parallel cpu backend";
-
// Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation.
// DependencyHloOrdering is used for the parallel emitter because the order
@@ -500,8 +497,6 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
.set_ir_module_string(ir_module_string);
}
} else {
- VLOG(1) << "Using sequential cpu backend";
-
// Select an order for emitting the HLO instructions for each
// computation. Using this sequence enables tighter buffer liveness analysis
// and reduced memory usage (as compared to using DependencyHloOrdering).
@@ -567,7 +562,6 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
}
}
- VLOG(1) << "Compilation finished";
return std::move(cpu_executable);
}
@@ -669,7 +663,6 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
std::vector<std::unique_ptr<AotCompilationResult>> results;
for (size_t i = 0; i < modules.size(); ++i) {
HloModule* module = modules[i].get();
- VLOG(1) << "Compiling ahead-of-time: " << module->name();
TF_RETURN_IF_ERROR(RunHloPasses(module, dump_hlo));
@@ -748,8 +741,6 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
std::move(object_file_data), std::move(buffer_sizes),
result_slice.index()));
}
-
- VLOG(1) << "Compilation finished";
return std::move(results);
}
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 5b21ae3d2a..db0a8b36cd 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -949,9 +949,20 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
IrArray::Index sliced_index(index.size());
for (int i = 0; i < index.size(); ++i) {
- sliced_index[i] = ir_builder_->CreateAdd(
- index[i], llvm::ConstantInt::get(index[i]->getType(),
- hlo->slice_starts(i)));
+ int64 stride = hlo->slice_stride(i);
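+          // With a non-unit stride, output index i maps to input element
+          // i * stride + slice_start in this dimension.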
+ if (stride != 1) {
+ sliced_index[i] = ir_builder_->CreateAdd(
+ ir_builder_->CreateMul(
+ index[i], llvm::ConstantInt::get(index[i]->getType(),
+ stride)),
+ llvm::ConstantInt::get(index[i]->getType(),
+ hlo->slice_starts(i)));
+ } else {
+ sliced_index[i] = ir_builder_->CreateAdd(
+ index[i],
+ llvm::ConstantInt::get(index[i]->getType(),
+ hlo->slice_starts(i)));
+ }
}
return operand_to_generator.at(hlo->operand(0))(sliced_index);
};
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 4e130de311..b8c6162084 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -80,6 +80,7 @@ HloInstruction* MaybePaddedAndSlicedInput(
std::vector<int64> start_indices(input->shape().dimensions_size(), 0);
std::vector<int64> limit_indices(input->shape().dimensions().begin(),
input->shape().dimensions().end());
+ std::vector<int64> strides(input->shape().dimensions_size(), 1);
for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) {
int64 dim = conv_dnums.spatial_dimensions(i);
// If dimension "dim" has negative padding, increase the start index or
@@ -92,9 +93,9 @@ HloInstruction* MaybePaddedAndSlicedInput(
input = computation->AddInstruction(HloInstruction::CreateSlice(
ShapeInference::InferSliceShape(input->shape(), start_indices,
- limit_indices)
+ limit_indices, strides)
.ConsumeValueOrDie(),
- input, start_indices, limit_indices));
+ input, start_indices, limit_indices, strides));
}
return input;
@@ -354,6 +355,8 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
std::vector<int64> limit_indices(
new_backward_conv->shape().dimensions().begin(),
new_backward_conv->shape().dimensions().end());
+ std::vector<int64> strides(new_backward_conv->shape().dimensions_size(),
+ 1LL);
for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
int64 padding_low = backward_conv->window().dimensions(i).padding_low();
int64 padding_high = backward_conv->window().dimensions(i).padding_high();
@@ -373,13 +376,13 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
// Replace the old backward convolution with the slice.
CHECK(ShapeUtil::Compatible(
ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices,
- limit_indices)
+ limit_indices, strides)
.ConsumeValueOrDie(),
backward_conv->shape()));
TF_CHECK_OK(computation->ReplaceWithNewInstruction(
backward_conv,
HloInstruction::CreateSlice(backward_conv->shape(), new_backward_conv,
- start_indices, limit_indices)));
+ start_indices, limit_indices, strides)));
return true;
}
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index a643bc4076..1c60b06ddd 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -147,6 +147,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
const int64 dimensions[] = {11, 8, 7, 5, 9};
const int64 slice_start[] = {4, 2, 3, 1, 5};
const int64 slice_limits[] = {10, 8, 6, 5, 9};
+ const int64 slice_strides[] = {1, 1, 1, 1, 1};
TF_ASSIGN_OR_ASSERT_OK(auto literal,
LiteralTestUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
@@ -154,7 +155,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
HloInstruction::CreateConstant(std::move(literal)));
Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4});
builder.AddInstruction(HloInstruction::CreateSlice(
- shape, literal_instruction, slice_start, slice_limits));
+ shape, literal_instruction, slice_start, slice_limits, slice_strides));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 99b73dea29..9117ab9653 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -306,11 +306,13 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices) {
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides) {
auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSlice, shape));
instruction->AppendOperand(operand);
instruction->slice_starts_.assign(start_indices.begin(), start_indices.end());
instruction->slice_limits_.assign(limit_indices.begin(), limit_indices.end());
+ instruction->slice_strides_.assign(strides.begin(), strides.end());
return instruction;
}
@@ -852,7 +854,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
return CreateReshape(shape, new_operands[0]);
case HloOpcode::kSlice:
CHECK_EQ(new_operands.size(), 1);
- return CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_);
+ return CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_,
+ slice_strides_);
case HloOpcode::kDynamicSlice:
return CreateDynamicSlice(shape, new_operands[0], new_operands[1],
dynamic_slice_sizes_);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 37cbb0b769..d29c0935fc 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -174,7 +174,8 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateSlice(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices);
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides);
// Creates a slice instruction, where the first operand is sliced by
// start indices specified in the second operand, and by size specified in
@@ -662,6 +663,15 @@ class HloInstruction {
return slice_limits_;
}
+ // Returns the stride in the given dimension for a slice node.
+ //
+ // Precondition: opcode() == HloOpcode::kSlice
+ int64 slice_stride(int64 dimension) const {
+ CHECK_EQ(HloOpcode::kSlice, opcode_);
+ return slice_strides_[dimension];
+ }
+ const std::vector<int64>& slice_strides() const { return slice_strides_; }
+
// Returns the size of the slice in the given dimension for a dynamic
// slice node.
//
@@ -907,6 +917,7 @@ class HloInstruction {
// Describes the [begin, end) index range for a slice.
std::vector<int64> slice_starts_;
std::vector<int64> slice_limits_;
+ std::vector<int64> slice_strides_;
// The bit sizes for a reduce-precision operation.
int32 exponent_bits_;
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index 8a1e705711..1a861cd16b 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -67,7 +67,8 @@ class HloRematerializationTest : public HloTestBase {
/*dimension=*/0));
auto slice_1 = builder.AddInstruction(HloInstruction::CreateSlice(
vec1_shape_, concat_1, /*start_indices=*/{0},
- /*limit_indices=*/{1}));
+ /*limit_indices=*/{1},
+ /*strides=*/{1}));
auto concat_2 = builder.AddInstruction(HloInstruction::CreateConcatenate(
ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, slice_1},
/*dimension=*/0));
@@ -75,7 +76,8 @@ class HloRematerializationTest : public HloTestBase {
// which is necessary to use this computation in a while.
builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat_2,
/*start_indices=*/{0},
- /*limit_indices=*/{1}));
+ /*limit_indices=*/{1},
+ /*strides=*/{1}));
return builder.Build();
}
@@ -103,7 +105,8 @@ class HloRematerializationTest : public HloTestBase {
HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
auto slice_1 = builder.AddInstruction(
HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0},
- /*limit_indices=*/{1}));
+ /*limit_indices=*/{1},
+ /*strides=*/{1}));
auto while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
vec1_shape_, while_cond, while_body, slice_1));
auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
@@ -111,7 +114,8 @@ class HloRematerializationTest : public HloTestBase {
/*dimension=*/0));
builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat,
/*start_indices=*/{0},
- /*limit_indices=*/{1}));
+ /*limit_indices=*/{1},
+ /*strides=*/{1}));
return builder.Build();
}
@@ -353,7 +357,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
/*dimension=*/0));
builder.AddInstruction(HloInstruction::CreateSlice(
vec1024_shape_, concat, /*start_indices=*/{0},
- /*limit_indices=*/{1024}));
+      /*limit_indices=*/{1024}, /*strides=*/{1}));
subcomputation = module->AddEmbeddedComputation(builder.Build());
}
@@ -469,7 +473,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
/*dimension=*/0));
builder.AddInstruction(HloInstruction::CreateSlice(
vec1024_shape_, concat, /*start_indices=*/{0},
- /*limit_indices=*/{1024}));
+      /*limit_indices=*/{1024}, /*strides=*/{1}));
subcomputation = module->AddEmbeddedComputation(builder.Build());
}
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index e348511c62..bcc9418d59 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -356,9 +356,26 @@ void EmitLogging(const char* tag, llvm::Value* value,
void SetTbaaForInstruction(llvm::Instruction* instruction, Shape shape,
bool is_pointer_to) {
- // TODO(b/62903316): TBAA metadata causes LLVM to miscompile generated code,
- // most likely because the generated metadata is incorrect. Disable TBAA
- // metadata while we resolve this.
+ llvm::MDBuilder metadata_builder(instruction->getContext());
+ llvm::MDNode* root = metadata_builder.createTBAARoot("XLA TBAA");
+ string type_name;
+ if (is_pointer_to) {
+ type_name += "pointer-to ";
+ }
+  // Scalars do not have a layout, which makes it permissible to omit an
+  // explicit layout. To make sure that equivalent scalar shapes have the same
+  // TBAA, remove the (meaningless) explicit layout if one is present.
+ if (!ShapeUtil::IsArray(shape) || ShapeUtil::IsScalar(shape)) {
+ LayoutUtil::ClearLayout(&shape);
+ } else {
+ CHECK(shape.has_layout());
+ }
+ type_name += shape.ShortDebugString();
+ llvm::MDNode* tbaa_node =
+ metadata_builder.createTBAANode(llvm_ir::AsStringRef(type_name), root);
+ instruction->setMetadata(llvm::LLVMContext::MD_tbaa,
+ metadata_builder.createTBAAStructTagNode(
+ tbaa_node, tbaa_node, /*Offset=*/0));
}
void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) {
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index b332709995..5e4df9ddd6 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1135,7 +1135,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> limits) {
+ tensorflow::gtl::ArraySlice<int64> limits,
+ tensorflow::gtl::ArraySlice<int64> strides) {
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice"));
VLOG(2) << tensorflow::strings::Printf(
"slicing shape %s starts={%s} limits={%s}",
@@ -1158,13 +1159,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
for (int64 dimension = 0; dimension < starts.size(); ++dimension) {
int64 start_index = starts[dimension];
int64 limit_index = limits[dimension];
+ int64 stride = strides[dimension];
if (start_index < 0) {
return InvalidArgument("negative start index to slice: %lld",
start_index);
}
- if (limit_index < 0) {
- return InvalidArgument("negative limit index to slice: %lld",
- limit_index);
+ if (stride == 0) {
+ return InvalidArgument("Zero stride");
}
if (limit_index > arg.dimensions(dimension)) {
return InvalidArgument(
@@ -1172,18 +1173,21 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
"size (%lld)",
limit_index, arg.dimensions(dimension));
}
- if (start_index > limit_index) {
- return InvalidArgument(
- "limit index (%lld) must be greater or equal to "
- "start index (%lld) in slice",
- limit_index, start_index);
- }
VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension,
start_index);
VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension,
limit_index);
-
- sizes.push_back(limits[dimension] - starts[dimension]);
+ if (stride > 0) {
+ if (start_index > limit_index) {
+ return InvalidArgument(
+ "limit index (%lld) must be greater or equal to "
+ "start index (%lld) in slice with positive stride",
+ limit_index, start_index);
+ }
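+      // The resulting size in this dimension is ceil((limit - start) / stride);
+      // e.g. start=15, limit=20, stride=2 yields 3 elements.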
+ sizes.push_back((limit_index - start_index + stride - 1) / stride);
+ } else {
+ return InvalidArgument("Negative strides not supported");
+ }
}
return ShapeUtil::MakeShape(arg.element_type(), sizes);
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 55c60e149d..42e4c7d39d 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -116,7 +116,8 @@ class ShapeInference {
// e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16]
static StatusOr<Shape> InferSliceShape(
const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> limits);
+ tensorflow::gtl::ArraySlice<int64> limits,
+ tensorflow::gtl::ArraySlice<int64> strides);
// Infers the shape produced by a dynamic slice operation of size specified
// in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'.
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 7cff042a48..8c731ae297 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -682,16 +682,43 @@ TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
auto inferred_status =
- ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64});
+ ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {1, 1});
ASSERT_IS_OK(inferred_status.status());
Shape inferred = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 64}), inferred));
}
+TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) {
+ Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
+ auto inferred_status =
+ ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {2, 4});
+ ASSERT_IS_OK(inferred_status.status());
+ Shape inferred = inferred_status.ValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}), inferred));
+}
+
+TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) {
+ Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
+ auto inferred_status =
+ ShapeInference::InferSliceShape(matrix_shape, {15, 0}, {20, 13}, {2, 4});
+ ASSERT_IS_OK(inferred_status.status());
+ Shape inferred = inferred_status.ValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {3, 4}), inferred));
+}
+
+TEST_F(ShapeInferenceTest, InferInvalidStride) {
+ Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
+ auto inferred_status =
+ ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {0, 1});
+ ASSERT_FALSE(inferred_status.ok());
+ ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
+ inferred_status.status().code());
+}
+
TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) {
Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
auto inferred_status =
- ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2});
+ ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {1, 1});
ASSERT_FALSE(inferred_status.ok());
ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
inferred_status.status().code());
@@ -700,7 +727,7 @@ TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) {
TEST_F(ShapeInferenceTest, InferSliceShapeRank1) {
Shape vector_shape = ShapeUtil::MakeShape(F32, {17});
auto inferred_status =
- ShapeInference::InferSliceShape(vector_shape, {2}, {4});
+ ShapeInference::InferSliceShape(vector_shape, {2}, {4}, {1});
ASSERT_TRUE(inferred_status.ok());
Shape inferred = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {2})));
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index d25e5adee3..cd79e63caf 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -584,7 +584,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest {
if (add_additional_gte0_user) {
// Create 'slice' as an additional user of 'input'.
auto slice = builder.AddInstruction(
- HloInstruction::CreateSlice(update_shape, input, {0}, {3}));
+ HloInstruction::CreateSlice(update_shape, input, {0}, {3}, {1}));
// Modify 'update' to take 'slice' output.
update = builder.AddInstruction(HloInstruction::CreateBinary(
update_shape, HloOpcode::kAdd, update, slice));
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 1f6e789379..92b8c7bb21 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -744,7 +744,8 @@ StatusOr<ComputationDataHandle> UserComputation::AddSliceInstruction(
Shape new_shape,
ShapeInference::InferSliceShape(
operand->output_shape(), AsInt64Slice(slice_request.start_indices()),
- AsInt64Slice(slice_request.limit_indices())));
+ AsInt64Slice(slice_request.limit_indices()),
+ AsInt64Slice(slice_request.stride())));
ComputationDataHandle handle = CreateComputationDataHandle();
@@ -2393,7 +2394,8 @@ void ComputationLowerer::Visit(
hlo_instruction = add_instruction(HloInstruction::CreateSlice(
request.output_shape(), operand,
AsInt64Slice(slice_request.start_indices()),
- AsInt64Slice(slice_request.limit_indices())));
+ AsInt64Slice(slice_request.limit_indices()),
+ AsInt64Slice(slice_request.stride())));
break;
}
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index bb7fbad000..024988743c 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -1853,7 +1853,7 @@ TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
auto x = builder.Parameter(0, x_literal->shape(), "x");
auto y = builder.Parameter(1, y_literal->shape(), "y");
- auto slice = builder.Slice(x, {1}, {2});
+ auto slice = builder.Slice(x, {1}, {2}, {1});
builder.Sub(slice, y);
ComputeAndCompareR1<float>(&builder, {-2, -3}, {x_data.get(), y_data.get()},
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 7abef6a27b..63a630f9e5 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -365,9 +365,9 @@ XLA_TEST_F(DotOperationTest, BatchMatMul) {
std::vector<xla::ComputationDataHandle> out_slices;
for (int i = 0; i < 4; ++i) {
// Slice off individual matrices and reshape to 2D tensors.
- auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2});
+ auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
x_slice = builder.Reshape(x_slice, {0, 1, 2}, {2, 2});
- auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2});
+ auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
y_slice = builder.Reshape(y_slice, {0, 1, 2}, {2, 2});
auto out = builder.Dot(x_slice, y_slice);
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index c8b91eafc7..7803d234fd 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -210,7 +210,7 @@ XLA_TEST_F(FusionTest, Test) {
HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}),
HloOpcode::kSelect, const10, add8, const9));
auto slice12 = builder.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2}));
+ ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2}, {1, 1}));
// CreateFusionInstruction needs the `instructions_to_fuse` argument in
// reverse topological order, so the first element in `instructions_to_fuse`
// must be the root.
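For context on the reverse-topological-order requirement mentioned above, the fusion in this test is created along these lines (a hedged sketch, not part of this hunk; `computation` is assumed to be the entry computation built from `builder`, and the real test fuses the full instruction chain):

  // The root (slice12) comes first in instructions_to_fuse, followed by its
  // operands in reverse topological order.
  computation->CreateFusionInstruction(
      {slice12, select11, const10, add8, const9},
      HloInstruction::FusionKind::kLoop);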
diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
index df3d4fa21d..56c15e5ff7 100644
--- a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
+++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
@@ -36,7 +36,7 @@ XLA_TEST_F(SliceTest, Slice2D) {
ComputationBuilder builder(client_, "slice_2d");
auto original = builder.ConstantR2<float>(
{{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}, {10.0, 11.0, 12.0}});
- builder.Slice(original, {2, 1}, {4, 3});
+ builder.Slice(original, {2, 1}, {4, 3}, {1, 1});
Array2D<float> expected({{8.0f, 9.0f}, {11.0f, 12.0f}});
ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
@@ -47,7 +47,7 @@ XLA_TEST_F(SliceTest, Slice3D) {
Array3D<float> array_3d(
{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}});
auto original = builder.ConstantR3FromArray3D<float>(array_3d);
- builder.Slice(original, {0, 0, 1}, {2, 1, 2});
+ builder.Slice(original, {0, 0, 1}, {2, 1, 2}, {1, 1, 1});
Array3D<float> expected_3d({{{2.0f}}, {{6.0f}}});
ComputeAndCompareR3<float>(&builder, expected_3d, {}, ErrorSpec(0.000001));
diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc
index 2065e9e813..a7692fceb4 100644
--- a/tensorflow/compiler/xla/tests/params_test.cc
+++ b/tensorflow/compiler/xla/tests/params_test.cc
@@ -325,7 +325,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
ComputationBuilder builder(client_, TestName());
auto input = builder.Parameter(0, original, "input");
// Use the slice operator to get an off-diagonal element.
- builder.Slice(input, {0, 1}, {1, 2});
+ builder.Slice(input, {0, 1}, {1, 2}, {1, 1});
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index 97120df0c5..5e7d475662 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -44,7 +44,7 @@ class SliceTest : public ClientLibraryTestBase {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR1<NativeT>(constant);
- builder.Slice(original, {2}, {4});
+ builder.Slice(original, {2}, {4}, {1});
const std::vector<NativeT> expected = {static_cast<NativeT>(2),
static_cast<NativeT>(3)};
@@ -55,7 +55,7 @@ class SliceTest : public ClientLibraryTestBase {
XLA_TEST_F(SliceTest, SliceZeroToZeroF32) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR1<float>({});
- builder.Slice(original, {0}, {0});
+ builder.Slice(original, {0}, {0}, {1});
ComputeAndCompareR1<float>(&builder, {}, {});
}
@@ -64,7 +64,7 @@ XLA_TEST_F(SliceTest, SliceTenToZeroF32) {
ComputationBuilder builder(client_, TestName());
std::vector<float> constant(10, 0.3);
auto original = builder.ConstantR1<float>(constant);
- builder.Slice(original, {7}, {7});
+ builder.Slice(original, {7}, {7}, {1});
ComputeAndCompareR1<float>(&builder, {}, {});
}
@@ -87,7 +87,7 @@ TEST_F(SliceTest, SliceTenToTen) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR1<float>(values);
- builder.Slice(original, {0}, {10});
+ builder.Slice(original, {0}, {10}, {1});
ComputeAndCompareR1<float>(&builder, values, {}, ErrorSpec(0.000001));
}
@@ -98,7 +98,7 @@ TEST_F(SliceTest, SliceLastFourOf1024) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR1<float>(values);
- builder.Slice(original, {1024 - 4}, {1024});
+ builder.Slice(original, {1024 - 4}, {1024}, {1});
const std::vector<float> expected = {1020, 1021, 1022, 1023};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.000001));
@@ -112,7 +112,7 @@ TEST_F(SliceTest, DISABLED_SliceUnaligned1024In4096Values) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR1<float>(values);
- builder.Slice(original, {7}, {7 + 1024});
+ builder.Slice(original, {7}, {7 + 1024}, {1});
std::vector<float> expected(1024);
std::iota(values.begin(), values.end(), 7.0);
@@ -122,7 +122,7 @@ TEST_F(SliceTest, DISABLED_SliceUnaligned1024In4096Values) {
XLA_TEST_F(SliceTest, Slice0x0to0x0F32) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 0));
- builder.Slice(original, {0, 0}, {0, 0});
+ builder.Slice(original, {0, 0}, {0, 0}, {1, 1});
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {});
}
@@ -130,7 +130,7 @@ XLA_TEST_F(SliceTest, Slice0x0to0x0F32) {
XLA_TEST_F(SliceTest, Slice0x20to0x5F32) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 20));
- builder.Slice(original, {0, 15}, {0, 20});
+ builder.Slice(original, {0, 15}, {0, 20}, {1, 1});
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 5), {});
}
@@ -138,7 +138,7 @@ XLA_TEST_F(SliceTest, Slice0x20to0x5F32) {
XLA_TEST_F(SliceTest, Slice3x0to2x0F32) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(3, 0));
- builder.Slice(original, {1, 0}, {3, 0});
+ builder.Slice(original, {1, 0}, {3, 0}, {1, 1});
ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 0), {});
}
@@ -153,7 +153,7 @@ XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR2FromArray2D<float>(values);
- builder.Slice(original, {128, 128}, {256, 256});
+ builder.Slice(original, {128, 128}, {256, 256}, {1, 1});
Array2D<float> expected(128, 128);
for (int row = 0; row < 128; ++row) {
@@ -171,7 +171,7 @@ TEST_F(SliceTest, Slice_1x4096_To_1x1024) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR2FromArray2D<float>(values);
- builder.Slice(original, {0, 3072}, {1, 4096});
+ builder.Slice(original, {0, 3072}, {1, 4096}, {1, 1});
Array2D<float> expected(1, 1024);
std::iota(expected.data(), expected.data() + 1024, 3072.0);
@@ -192,7 +192,7 @@ TEST_F(SliceTest, Slice_16x4_To_16x2) {
}
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR2FromArray2D<float>(values);
- builder.Slice(original, {0, 0}, {16, 2});
+ builder.Slice(original, {0, 0}, {16, 2}, {1, 1});
ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
}
@@ -204,7 +204,7 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) {
ReferenceUtil::Slice4D(values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}});
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR4FromArray4D(values);
- builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128});
+ builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1});
ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001));
}
@@ -213,6 +213,7 @@ struct R2Spec {
int64 input_dim1;
std::array<int64, 2> slice_starts;
std::array<int64, 2> slice_limits;
+ std::array<int64, 2> slice_strides;
Layout layout;
};
@@ -228,7 +229,7 @@ TEST_P(SliceR2Test, DoIt) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR2FromArray2D<int32>(input);
- builder.Slice(a, spec.slice_starts, spec.slice_limits);
+ builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides);
std::unique_ptr<Array2D<int32>> expected =
ReferenceUtil::Slice2D(input, spec.slice_starts, spec.slice_limits);
@@ -239,19 +240,23 @@ TEST_P(SliceR2Test, DoIt) {
INSTANTIATE_TEST_CASE_P(
SliceR2TestInstantiation, SliceR2Test,
::testing::Values(
- R2Spec {4, 12, {{0, 3}}, {{4, 6}}, LayoutUtil::MakeLayout({0, 1})},
- R2Spec {4, 12, {{0, 3}}, {{4, 6}}, LayoutUtil::MakeLayout({1, 0})},
- R2Spec {16, 4, {{0, 2}}, {{16, 4}}, LayoutUtil::MakeLayout({0, 1})},
- R2Spec {16, 4, {{0, 2}}, {{16, 4}}, LayoutUtil::MakeLayout({1, 0})},
- R2Spec {256, 400, {{0, 300}}, {{256, 400}},
+ R2Spec {4, 12, {{0, 3}}, {{4, 6}}, {{1, 1}},
+ LayoutUtil::MakeLayout({0, 1})},
+ R2Spec {4, 12, {{0, 3}}, {{4, 6}}, {{1, 1}},
LayoutUtil::MakeLayout({1, 0})},
- R2Spec {500, 400, {{111, 123}}, {{300, 257}},
+ R2Spec {16, 4, {{0, 2}}, {{16, 4}}, {{1, 1}},
+ LayoutUtil::MakeLayout({0, 1})},
+ R2Spec {16, 4, {{0, 2}}, {{16, 4}}, {{1, 1}},
LayoutUtil::MakeLayout({1, 0})},
- R2Spec {500, 400, {{111, 123}}, {{300, 400}},
+ R2Spec {256, 400, {{0, 300}}, {{256, 400}}, {{1, 1}},
LayoutUtil::MakeLayout({1, 0})},
- R2Spec {384, 512, {{128, 256}}, {{256, 384}},
+ R2Spec {500, 400, {{111, 123}}, {{300, 257}}, {{1, 1}},
LayoutUtil::MakeLayout({1, 0})},
- R2Spec {357, 512, {{111, 256}}, {{301, 384}},
+ R2Spec {500, 400, {{111, 123}}, {{300, 400}}, {{1, 1}},
+ LayoutUtil::MakeLayout({1, 0})},
+ R2Spec {384, 512, {{128, 256}}, {{256, 384}}, {{1, 1}},
+ LayoutUtil::MakeLayout({1, 0})},
+ R2Spec {357, 512, {{111, 256}}, {{301, 384}}, {{1, 1}},
LayoutUtil::MakeLayout({1, 0})}
)
);
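All of the instantiated specs above keep unit strides, so the existing ReferenceUtil::Slice2D comparison in DoIt still applies unchanged. A spec that actually exercised strides would look roughly like the following (illustrative only; it would also need a stride-aware reference slice to compare against):

  R2Spec {256, 400, {{0, 300}}, {{256, 400}}, {{2, 4}},
          LayoutUtil::MakeLayout({1, 0})}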
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index ccd2a95658..afa7d871c0 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -666,7 +666,8 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) {
auto build_condition = [this, v6s32](int count) {
ComputationBuilder builder(client_, TestName());
auto prev = builder.Reshape(
- builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}), {0}, {});
+ builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}, {1}), {0},
+ {});
builder.Gt(builder.ConstantR0<int32>(count), prev);
return builder.Build().ConsumeValueOrDie();
};
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index 42d5c1d155..31f0c3147e 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -195,16 +195,24 @@ bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank);
// 2. permutation.size() == input.size().
template <template <typename...> class C, typename T>
std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
- C<T> input_) {
- tensorflow::gtl::ArraySlice<T> input(input_);
- CHECK(IsPermutation(permutation, input.size()));
- std::vector<T> output(input.size());
+ C<T> input) {
+ tensorflow::gtl::ArraySlice<T> data(input);
+ CHECK(IsPermutation(permutation, data.size()));
+ std::vector<T> output(data.size());
for (size_t i = 0; i < permutation.size(); ++i) {
- output[permutation[i]] = input[i];
+ output[permutation[i]] = data[i];
}
return output;
}
+// Override of the above that works around compile failures with gcc 7.1.1.
+// For details see https://github.com/tensorflow/tensorflow/issues/10843
+template <typename T>
+std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
+ const std::vector<T>& input) {
+ return Permute<std::vector, T>(permutation, input);
+}
+
// Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i.
std::vector<int64> InversePermutation(
tensorflow::gtl::ArraySlice<int64> input_permutation);
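The new overload simply forwards std::vector arguments to the templated version, so existing call sites keep working under gcc 7.1.1. A minimal usage sketch (not from the diff; int64 is XLA's usual 64-bit typedef):

  std::vector<int64> permutation = {2, 0, 1};
  std::vector<float> values = {10.0f, 20.0f, 30.0f};
  // Resolves to the std::vector overload, which forwards to the template;
  // output[permutation[i]] = values[i], giving {20.0f, 30.0f, 10.0f}.
  std::vector<float> permuted = Permute(permutation, values);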
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 95c1f0995b..86c72b3449 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -200,7 +200,7 @@ message OpMetadata {
string op_name = 2;
// Indicate a file and line that this op is associated to in a user's program.
//
- // e.g. it could be be the file and line of user code that generated the op.
+ // e.g. it could be the file and line of user code that generated the op.
string source_file = 3;
int32 source_line = 4;
}
@@ -369,6 +369,7 @@ message SliceRequest {
ComputationDataHandle operand = 2;
repeated int64 start_indices = 3;
repeated int64 limit_indices = 4;
+ repeated int64 stride = 5;
}
message DynamicSliceRequest {
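With the new repeated stride field (field 5), a slice request carrying strides would be populated through the standard protoc-generated accessors, roughly as follows (a sketch; operand_handle is a hypothetical ComputationDataHandle, and in practice the client builder fills this in from its strides argument):

  SliceRequest request;
  *request.mutable_operand() = operand_handle;
  request.add_start_indices(32);
  request.add_limit_indices(64);
  request.add_stride(2);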