author    Roy Frostig <frostig@google.com>    2017-12-15 10:38:16 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-12-15 10:42:45 -0800
commit    75a91cf3be635af4f6004f20f3c3cc50c37d3145 (patch)
tree      32ca27cf5fdf175afe790c74b2d1e296dc35e3ff
parent    22fe6558a958c6cc81d16d371031c06e262b1c83 (diff)
Python library and C++ bindings for creating and compiling local XLA computations.
PiperOrigin-RevId: 179211353
-rw-r--r--  tensorflow/BUILD                                             |    1
-rw-r--r--  tensorflow/compiler/xla/BUILD                                |   10
-rw-r--r--  tensorflow/compiler/xla/python/BUILD                         |   82
-rw-r--r--  tensorflow/compiler/xla/python/__init__.py                   |    0
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.cc  |  265
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.h   |  210
-rw-r--r--  tensorflow/compiler/xla/python/local_computation_builder.i   |  348
-rw-r--r--  tensorflow/compiler/xla/python/numpy_bridge.cc               |  389
-rw-r--r--  tensorflow/compiler/xla/python/numpy_bridge.h                |  123
-rw-r--r--  tensorflow/compiler/xla/python/xla.i                         |   18
-rw-r--r--  tensorflow/compiler/xla/python/xla_client.py                 |  605
-rw-r--r--  tensorflow/compiler/xla/python/xla_client_test.py            |  898
-rw-r--r--  tensorflow/tf_exported_symbols.lds                           |    1
-rw-r--r--  tensorflow/tf_version_script.lds                             |    1
14 files changed, 2951 insertions(+), 0 deletions(-)
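Taken together, the new files implement an in-process Python client for XLA. A minimal usage sketch follows; hedged, since the code that attaches the _UNARY_OPS/_BINARY_OPS tables as builder methods falls outside the excerpt below, so the exact Add/Mul signatures (with an explicit broadcast_dimensions sequence) are an assumption:

    import numpy as np
    from tensorflow.compiler.xla.python import xla_client

    # Build a computation: f(x, y) = 2*x + y on F32[3] inputs.
    builder = xla_client.ComputationBuilder('axpy')
    a = builder.ConstantF32Scalar(2.0)
    x = builder.ParameterFromNumpy(np.zeros(3, dtype=np.float32))
    y = builder.ParameterFromNumpy(np.zeros(3, dtype=np.float32))
    builder.Add(builder.Mul(a, x, ()), y, ())  # () = no explicit broadcast dims

    computation = builder.Build()
    args = (np.array([1, 2, 3], dtype=np.float32),
            np.array([10, 20, 30], dtype=np.float32))
    compiled = computation.CompileWithExampleArguments(args)
    print(compiled.Execute(args))  # expect [12. 24. 36.]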
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index d80fe5c829..9437bef99f 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -411,6 +411,7 @@ filegroup(
"//tensorflow/compiler/xla/client:all_files",
"//tensorflow/compiler/xla/client/lib:all_files",
"//tensorflow/compiler/xla/legacy_flags:all_files",
+ "//tensorflow/compiler/xla/python:all_files",
"//tensorflow/compiler/xla/service:all_files",
"//tensorflow/compiler/xla/service/cpu:all_files",
"//tensorflow/compiler/xla/service/gpu:all_files",
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index d3f292207f..cd69c69889 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -20,6 +20,10 @@ package_group(
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_proto_library_py",
+)
# Filegroup used to collect source files for dependency checking.
filegroup(
@@ -36,6 +40,12 @@ xla_proto_library(
visibility = ["//visibility:public"],
)
+tf_proto_library_py(
+ name = "xla_data_proto", # bzl adds a _py suffix
+ srcs = ["xla_data.proto"],
+ visibility = ["//visibility:public"],
+)
+
xla_proto_library(
name = "xla_proto",
srcs = ["xla.proto"],
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
new file mode 100644
index 0000000000..a6b8158671
--- /dev/null
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -0,0 +1,82 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
+
+py_library(
+ name = "xla_client",
+ srcs = ["xla_client.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":pywrap_xla",
+ "//tensorflow/compiler/xla:xla_data_proto_py",
+ ],
+)
+
+py_test(
+ name = "xla_client_test",
+ srcs = ["xla_client_test.py"],
+ main = "xla_client_test.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":xla_client",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cc_library(
+ name = "numpy_bridge",
+ srcs = ["numpy_bridge.cc"],
+ hdrs = ["numpy_bridge.h"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/python:numpy_lib",
+ ],
+)
+
+cc_library(
+ name = "local_computation_builder",
+ srcs = ["local_computation_builder.cc"],
+ hdrs = ["local_computation_builder.h"],
+ deps = [
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_py_wrap_cc(
+ name = "pywrap_xla",
+ srcs = ["xla.i"],
+ swig_includes = [
+ "local_computation_builder.i",
+ ],
+ deps = [
+ ":local_computation_builder",
+ ":numpy_bridge",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/xla/python/__init__.py b/tensorflow/compiler/xla/python/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/compiler/xla/python/__init__.py
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
new file mode 100644
index 0000000000..0b0a53fac7
--- /dev/null
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -0,0 +1,265 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/python/local_computation_builder.h"
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace xla {
+
+namespace swig {
+
+CompiledLocalComputation::CompiledLocalComputation(
+ std::unique_ptr<LocalExecutable> executable)
+ : executable_(std::move(executable)) {}
+
+std::unique_ptr<Literal> CompiledLocalComputation::Execute(
+ const std::vector<Literal>& arguments) {
+ LocalClient* client = ClientLibrary::LocalClientOrDie();
+
+ // Transfer arguments in
+ std::vector<std::unique_ptr<ScopedShapedBuffer>> scoped_buffers;
+ scoped_buffers.reserve(arguments.size());
+ for (const Literal& argument : arguments) {
+ scoped_buffers.push_back(
+ client
+ ->LiteralToShapedBuffer(argument,
+ /*device_ordinal=*/0,
+ client->backend().memory_allocator())
+ .ConsumeValueOrDie());
+ }
+
+ // Execute
+ std::vector<const ShapedBuffer*> argument_buffers;
+ argument_buffers.reserve(scoped_buffers.size());
+ for (auto& buffer : scoped_buffers) {
+ argument_buffers.push_back(buffer.get());
+ }
+ ExecutableRunOptions options;
+ options.set_allocator(client->backend().memory_allocator());
+ options.set_inter_op_thread_pool(client->backend().inter_op_thread_pool());
+ options.set_intra_op_thread_pool(
+ client->backend().eigen_intra_op_thread_pool_device());
+ std::unique_ptr<ScopedShapedBuffer> result_buffer =
+ executable_->Run(argument_buffers, options).ConsumeValueOrDie();
+
+ // Transfer result out
+ return client->ShapedBufferToLiteral(*result_buffer).ConsumeValueOrDie();
+}
+
+LocalComputation::LocalComputation(std::unique_ptr<Computation> computation)
+ : computation_(std::move(computation)) {}
+
+CompiledLocalComputation* LocalComputation::Compile(
+ const std::vector<Shape>& argument_shapes) {
+ std::vector<const Shape*> argument_shape_pointers;
+ argument_shape_pointers.reserve(argument_shapes.size());
+ for (auto& argument_shape : argument_shapes) {
+ argument_shape_pointers.push_back(&argument_shape);
+ }
+
+ LocalClient* client = ClientLibrary::LocalClientOrDie();
+ ExecutableBuildOptions options;
+ return new CompiledLocalComputation(
+ client->Compile(*computation_, argument_shape_pointers, options)
+ .ValueOrDie());
+}
+
+const Computation& LocalComputation::computation() const {
+ return *computation_;
+}
+
+LocalComputationBuilder::LocalComputationBuilder(const string& computation_name)
+ : builder_(ClientLibrary::LocalClientOrDie(), computation_name) {}
+
+LocalComputation* LocalComputationBuilder::Build() {
+ return new LocalComputation(std::unique_ptr<Computation>(
+ new Computation(builder_.Build().ConsumeValueOrDie())));
+}
+
+ComputationDataHandle LocalComputationBuilder::Parameter(int64 parameter_number,
+ const Shape& shape,
+ const string& name) {
+ return builder_.Parameter(parameter_number, shape, name);
+}
+
+std::unique_ptr<Shape> LocalComputationBuilder::GetShape(
+ const ComputationDataHandle& operand) {
+ return builder_.GetShape(operand).ConsumeValueOrDie();
+}
+
+ComputationDataHandle LocalComputationBuilder::ConstantLiteral(
+ const Literal& literal) {
+ return builder_.ConstantLiteral(literal);
+}
+
+ComputationDataHandle LocalComputationBuilder::Broadcast(
+ const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
+ return builder_.Broadcast(operand, broadcast_sizes);
+}
+
+ComputationDataHandle LocalComputationBuilder::Reshape(
+ const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> new_sizes) {
+ return builder_.Reshape(operand, dimensions, new_sizes);
+}
+
+ComputationDataHandle LocalComputationBuilder::Slice(
+ const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides) {
+ return builder_.Slice(operand, start_indices, limit_indices, strides);
+}
+
+ComputationDataHandle LocalComputationBuilder::DynamicSlice(
+ const ComputationDataHandle& operand,
+ const ComputationDataHandle& start_indices,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ return builder_.DynamicSlice(operand, start_indices, slice_sizes);
+}
+
+ComputationDataHandle LocalComputationBuilder::DynamicUpdateSlice(
+ const ComputationDataHandle& operand, const ComputationDataHandle& update,
+ const ComputationDataHandle& start_indices) {
+ return builder_.DynamicUpdateSlice(operand, update, start_indices);
+}
+
+ComputationDataHandle LocalComputationBuilder::ConcatInDim(
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
+ int64 dimension) {
+ return builder_.ConcatInDim(operands, dimension);
+}
+
+ComputationDataHandle LocalComputationBuilder::Select(
+ const ComputationDataHandle& pred, const ComputationDataHandle& on_true,
+ const ComputationDataHandle& on_false) {
+ return builder_.Select(pred, on_true, on_false);
+}
+
+ComputationDataHandle LocalComputationBuilder::Tuple(
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> elements) {
+ return builder_.Tuple(elements);
+}
+
+ComputationDataHandle LocalComputationBuilder::GetTupleElement(
+ const ComputationDataHandle& tuple_data, int64 index) {
+ return builder_.GetTupleElement(tuple_data, index);
+}
+
+ComputationDataHandle LocalComputationBuilder::Dot(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) {
+ return builder_.Dot(lhs, rhs);
+}
+
+ComputationDataHandle LocalComputationBuilder::ConvertElementType(
+ const ComputationDataHandle& operand, PrimitiveType new_element_type) {
+ return builder_.ConvertElementType(operand, new_element_type);
+}
+
+ComputationDataHandle LocalComputationBuilder::Call(
+ const LocalComputation& local_computation,
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> operands) {
+ return builder_.Call(local_computation.computation(), operands);
+}
+
+ComputationDataHandle LocalComputationBuilder::Transpose(
+ const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> permutation) {
+ return builder_.Transpose(operand, permutation);
+}
+
+ComputationDataHandle LocalComputationBuilder::Map(
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
+ const LocalComputation& local_computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands) {
+ return builder_.Map(operands, local_computation.computation(), dimensions,
+ static_operands);
+}
+
+ComputationDataHandle LocalComputationBuilder::Reduce(
+ const ComputationDataHandle& operand,
+ const ComputationDataHandle& init_value,
+ const LocalComputation& local_computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
+ return builder_.Reduce(operand, init_value, local_computation.computation(),
+ dimensions_to_reduce);
+}
+
+ComputationDataHandle LocalComputationBuilder::While(
+ const LocalComputation& condition, const LocalComputation& body,
+ const ComputationDataHandle& init) {
+ return builder_.While(condition.computation(), body.computation(), init);
+}
+
+#define _FORWARD(method_name, return_sig, args_sig, args) \
+ return_sig LocalComputationBuilder::method_name args_sig { \
+ return builder_.method_name args; \
+ }
+
+#define _FORWARD_UNOP(method_name) \
+ _FORWARD(method_name, ComputationDataHandle, \
+ (const ComputationDataHandle& operand), (operand))
+
+#define _FORWARD_BINOP(method_name) \
+ _FORWARD( \
+ method_name, ComputationDataHandle, \
+ (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions), \
+ (lhs, rhs, broadcast_dimensions))
+
+_FORWARD_BINOP(Eq)
+_FORWARD_BINOP(Ne)
+_FORWARD_BINOP(Ge)
+_FORWARD_BINOP(Gt)
+_FORWARD_BINOP(Lt)
+_FORWARD_BINOP(Le)
+_FORWARD_BINOP(Add)
+_FORWARD_BINOP(Sub)
+_FORWARD_BINOP(Mul)
+_FORWARD_BINOP(Div)
+_FORWARD_BINOP(Rem)
+_FORWARD_BINOP(Max)
+_FORWARD_BINOP(Min)
+_FORWARD_BINOP(And)
+_FORWARD_BINOP(Or)
+_FORWARD_UNOP(Not)
+_FORWARD_UNOP(Abs)
+_FORWARD_UNOP(Exp)
+_FORWARD_UNOP(Floor)
+_FORWARD_UNOP(Ceil)
+_FORWARD_UNOP(Log)
+_FORWARD_UNOP(Sign)
+_FORWARD_UNOP(Cos)
+_FORWARD_UNOP(Sin)
+_FORWARD_UNOP(Tanh)
+_FORWARD_UNOP(SqrtF32)
+_FORWARD_UNOP(SquareF32)
+_FORWARD_BINOP(Pow)
+_FORWARD_UNOP(IsFinite)
+_FORWARD_UNOP(ReciprocalF32)
+_FORWARD_UNOP(Neg)
+_FORWARD_UNOP(Sort)
+
+#undef _FORWARD
+#undef _FORWARD_UNOP
+#undef _FORWARD_BINOP
+
+} // namespace swig
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
new file mode 100644
index 0000000000..cbab45a5f0
--- /dev/null
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -0,0 +1,210 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
+#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
+
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace xla {
+
+namespace swig {
+
+// Wraps a LocalExecutable produced by compiling a
+// LocalComputation. The Execute method forwards to that of the
+// underlying LocalExecutable, and additionally handles transferring
+// arguments and return values in and back out of the client library's
+// local client. This class is intended to be made available to Python
+// via SWIG.
+class CompiledLocalComputation {
+ public:
+ CompiledLocalComputation(std::unique_ptr<LocalExecutable> executable);
+ std::unique_ptr<Literal> Execute(const std::vector<Literal>& arguments);
+
+ private:
+ std::unique_ptr<LocalExecutable> executable_;
+};
+
+// Wraps a Computation produced by a LocalComputationBuilder. The
+// Compile method compiles the computation to a (local) executable via
+// the client library's local client. This class is intended to be
+// made available to Python via SWIG.
+class LocalComputation {
+ public:
+ LocalComputation(std::unique_ptr<Computation> computation);
+ CompiledLocalComputation* Compile(const std::vector<Shape>& argument_shapes);
+ const Computation& computation() const;
+
+ private:
+ std::unique_ptr<Computation> computation_;
+};
+
+// Wraps the ComputationBuilder API in order to:
+// - Support consumption by SWIG in order to be made available to
+// Python.
+// - Set up the underlying builder to use the client library's
+// LocalClient.
+// - Wrap Computations in LocalComputations for Python access.
+// - Correspondingly unwrap incoming LocalComputations.
+class LocalComputationBuilder {
+ public:
+ LocalComputationBuilder(const string& computation_name);
+
+ LocalComputation* Build();
+
+ ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape,
+ const string& name);
+
+ std::unique_ptr<Shape> GetShape(const ComputationDataHandle& operand);
+
+ ComputationDataHandle ConstantLiteral(const Literal& literal);
+
+ ComputationDataHandle Broadcast(
+ const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+
+ ComputationDataHandle Reshape(const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+ ComputationDataHandle Slice(const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides);
+
+ ComputationDataHandle DynamicSlice(
+ const ComputationDataHandle& operand,
+ const ComputationDataHandle& start_indices,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
+
+ ComputationDataHandle DynamicUpdateSlice(
+ const ComputationDataHandle& operand, const ComputationDataHandle& update,
+ const ComputationDataHandle& start_indices);
+
+ ComputationDataHandle ConcatInDim(
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
+ int64 dimension);
+
+ ComputationDataHandle Select(const ComputationDataHandle& pred,
+ const ComputationDataHandle& on_true,
+ const ComputationDataHandle& on_false);
+
+ ComputationDataHandle Tuple(
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> elements);
+
+ ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data,
+ int64 index);
+
+ ComputationDataHandle Dot(const ComputationDataHandle& lhs,
+ const ComputationDataHandle& rhs);
+
+ ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand,
+ PrimitiveType new_element_type);
+
+ ComputationDataHandle Call(
+ const LocalComputation& local_computation,
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> operands);
+
+ ComputationDataHandle Transpose(
+ const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> permutation);
+
+ ComputationDataHandle Map(
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
+ const LocalComputation& local_computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands);
+
+ ComputationDataHandle Reduce(
+ const ComputationDataHandle& operand,
+ const ComputationDataHandle& init_value,
+ const LocalComputation& local_computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+
+ ComputationDataHandle While(const LocalComputation& condition,
+ const LocalComputation& body,
+ const ComputationDataHandle& init);
+
+#define _FORWARD(method_name, return_sig, args_sig) \
+ return_sig method_name args_sig;
+
+#define _FORWARD_UNOP(method_name) \
+ _FORWARD(method_name, ComputationDataHandle, \
+ (const ComputationDataHandle& operand))
+
+#define _FORWARD_BINOP(method_name) \
+ _FORWARD( \
+ method_name, ComputationDataHandle, \
+ (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions))
+
+ _FORWARD_BINOP(Eq)
+ _FORWARD_BINOP(Ne)
+ _FORWARD_BINOP(Ge)
+ _FORWARD_BINOP(Gt)
+ _FORWARD_BINOP(Lt)
+ _FORWARD_BINOP(Le)
+ _FORWARD_BINOP(Add)
+ _FORWARD_BINOP(Sub)
+ _FORWARD_BINOP(Mul)
+ _FORWARD_BINOP(Div)
+ _FORWARD_BINOP(Rem)
+ _FORWARD_BINOP(Max)
+ _FORWARD_BINOP(Min)
+ _FORWARD_BINOP(And)
+ _FORWARD_BINOP(Or)
+ _FORWARD_UNOP(Not)
+ _FORWARD_UNOP(Abs)
+ _FORWARD_UNOP(Exp)
+ _FORWARD_UNOP(Floor)
+ _FORWARD_UNOP(Ceil)
+ _FORWARD_UNOP(Log)
+ _FORWARD_UNOP(Sign)
+ _FORWARD_UNOP(Cos)
+ _FORWARD_UNOP(Sin)
+ _FORWARD_UNOP(Tanh)
+ _FORWARD_UNOP(SqrtF32)
+ _FORWARD_UNOP(SquareF32)
+ _FORWARD_BINOP(Pow)
+ _FORWARD_UNOP(IsFinite)
+ _FORWARD_UNOP(ReciprocalF32)
+ _FORWARD_UNOP(Neg)
+ _FORWARD_UNOP(Sort)
+
+#undef _FORWARD
+#undef _FORWARD_UNOP
+#undef _FORWARD_BINOP
+
+ private:
+ ComputationBuilder builder_;
+};
+
+static void DeleteLocalComputation(LocalComputation* computation) {
+ delete computation;
+}
+
+static void DeleteCompiledLocalComputation(
+ CompiledLocalComputation* computation) {
+ delete computation;
+}
+
+} // namespace swig
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
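The classes above, plus the two Delete* helpers, are exactly what SWIG exposes to Python as the pywrap_xla module. A hedged sketch of driving that raw layer directly, using the typemap encodings documented in local_computation_builder.i below (long integers for data handles, (dtype, dimensions) pairs for shapes); xla_client.py normally hides this level:

    import numpy as np
    from tensorflow.compiler.xla.python import pywrap_xla as c_api

    builder = c_api.LocalComputationBuilder('add_scalars'.encode('utf8'))
    f32_scalar = (np.dtype('float32'), ())      # shape pair: scalar F32
    x = builder.Parameter(0, f32_scalar, 'x'.encode('utf8'))  # x, y are longs
    y = builder.Parameter(1, f32_scalar, 'y'.encode('utf8'))
    builder.Add(x, y, ())                       # empty broadcast_dimensions

    computation = builder.Build()               # xla::swig::LocalComputation
    compiled = computation.Compile([f32_scalar, f32_scalar])
    result = compiled.Execute([np.array(1.0, np.float32),
                               np.array(2.0, np.float32)])  # 0-d array, 3.0
    c_api.DeleteCompiledLocalComputation(compiled)
    c_api.DeleteLocalComputation(computation)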
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
new file mode 100644
index 0000000000..ac8f3e4277
--- /dev/null
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -0,0 +1,348 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// SWIG typemaps and declarations for building, compiling, and
+// executing XLA computations, wrapping most of what is declared in
+// local_computation_builder.h.
+//
+// The typemaps below implement/assert the following correspondences
+// (with elaborations below):
+//
+// C++ Python
+// -------------------------------------+---------------------------------------
+// ComputationDataHandle <-> long
+// ArraySlice<int64> <- sequence of long
+// ArraySlice<ComputationDataHandle> <- sequence of long
+// Literal <-> (nested tuple of) numpy ndarray
+// std::vector<Literal> <- sequence of (nested tuple of) ndarray
+// Shape <-> pair holding (dtype, dimensions)
+// std::vector<Shape> <- sequence of shape information pairs
+// PrimitiveType <- int
+//
+// Arrows indicate whether a conversion only ever occurs in one
+// direction, or whether it is maintained bidirectionally. Also,
+// "long" and "int" denote the Python types so named, not C.
+//
+// The Python objects corresponding to C++ Literals have the type:
+//
+// T = ndarray | (T, ...)
+//
+// where a terminal numpy ndarray translates to a Literal with a
+// non-tuple Shape, an XLA primitive element type corresponding to the
+// ndarray's dtype. Meanwhile, a non-terminal "tuple of T" translates
+// to a tuple-shaped Literal whose tuple components are translated
+// recursively. For example, if x is a numpy ndarray in Python, with
+// shape (2, 3) and dtype of dtype('float32'), then x translates to a
+// Literal with rank 2, dimension 2 and 3, and XLA primitive type
+// F32. Meanwhile,
+//
+// (x, (x, x), (x,)),
+//
+// translates to a tuple-shaped XLA Literal, whose component subshapes
+// are a 2x3 F32-shaped literal followed by two tuple-shaped literals.
+//
+// The Python objects corresponding to C++ Shapes have the type:
+//
+// T = (dtype, S)
+// S = DIMENSIONS | TUPLE_SHAPES
+// DIMENSIONS = (int, ...)
+// TUPLE_SHAPES = (T, ...)
+//
+// In the pair described by the T rule, the terminal dtype determines
+// whether S expands as DIMENSIONS or TUPLE_SHAPES. Namely if it is
+// dtype('O'), numpy's object dtype, the structure represents a tuple
+// shape and the expansion of the non-terminal S is
+// TUPLE_SHAPES. Otherwise, dtype describes a primitive element type
+// and S expands into DIMENSIONS giving dimension sizes. For example:
+//
+// (dtype('float32'), (3, 5, 7))
+//
+// describes a 3x5x7 array of F32s, and
+//
+// (dtype('O'), ((dtype('float32'), (2, 3)),
+// (dtype('float64'), (4, 5))))
+//
+// describes a tuple shape with two subshapes: the first a 2x3 F32,
+// and the other a 4x5 F64.
+//
+// The Python int corresponding to a PrimitiveType enum must be valid
+// per xla_data.proto (e.g. xla_data.PRED, xla_data.F32).
+//
+// The SWIG object wrappers generated by this file are not intended
+// for end use, but rather for internal use in the Python XLA client,
+// xla_client.py.
+//
+// One central reason for the Python-side indirection is that the
+// Python-side objects produced by the typemaps in this file are
+// further packaged up by xla_client before being passed on. For
+// instance, xla_client wraps the long produced for a C++
+// ComputationDataHandle in a Python ComputationDataHandle proto,
+// rather than exposing a raw long outside of the client. Similarly,
+// the Python pair produced for a C++ Shape is further wrapped in a
+// Python class (xla_client.Shape) so as not to expose the raw pair
+// externally.
+//
+// Other SWIG object wrappers (e.g. of LocalComputation) are further
+// wrapped by xla_client in order to set up a custom destructor that
+// triggers memory deallocation on the C++ side.
+
+%include "tensorflow/python/platform/base.i"
+
+%{
+// Must be included first
+#include "tensorflow/python/lib/core/numpy.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/compiler/xla/python/numpy_bridge.h"
+#include "tensorflow/compiler/xla/python/local_computation_builder.h"
+
+using namespace xla;
+using namespace xla::swig;
+%}
+
+// Required to use PyArray_* functions.
+%init %{
+tensorflow::ImportNumpy();
+%}
+
+// ComputationDataHandle
+
+%typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) {
+ const int64 handle = numpy::PyIntOrPyLongToLong($input);
+ if (handle == -1 && PyErr_Occurred()) {
+ return NULL;
+ }
+ temp.set_handle(handle);
+ $1 = &temp;
+}
+
+%typemap(out) ComputationDataHandle {
+ $result = numpy::LongToPyIntOrPyLong($1.handle());
+}
+
+// ArraySlice<int64>
+
+%typemap(in) tensorflow::gtl::ArraySlice<int64>
+ (std::vector<int64> temps) {
+ if (!PySequence_Check($input)) {
+ PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
+ return NULL;
+ }
+ const int size = PySequence_Size($input);
+ temps.resize(size);
+ for (int i = 0; i < size; ++i) {
+ PyObject* o = PySequence_GetItem($input, i);
+ PyObject* py_int = numpy::PyNumberToPyInt(o);
+ if (!py_int) {
+ PyErr_SetString(
+ PyExc_TypeError,
+ "Argument sequence element cannot be converted to int");
+ Py_DECREF(o);
+ return NULL;
+ }
+ temps[i] = numpy::PyIntOrPyLongToLong(py_int);
+ if (temps[i] == -1 && PyErr_Occurred()) {
+ Py_DECREF(py_int);
+ Py_DECREF(o);
+ return NULL;
+ }
+ Py_DECREF(py_int);
+ Py_DECREF(o);
+ }
+ $1 = temps;
+}
+
+// ComputationDataHandle
+
+%typemap(in) tensorflow::gtl::ArraySlice<ComputationDataHandle>
+ (std::vector<ComputationDataHandle> temps) {
+ if (!PySequence_Check($input)) {
+ PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
+ return NULL;
+ }
+ const int size = PySequence_Size($input);
+ temps.resize(size);
+ for (int i = 0; i < size; ++i) {
+ PyObject* o = PySequence_GetItem($input, i);
+ PyObject* py_int = numpy::PyNumberToPyInt(o);
+ if (!py_int) {
+ PyErr_SetString(
+ PyExc_TypeError,
+ "Argument sequence element cannot be converted to int");
+      Py_DECREF(o);
+      return NULL;
+ }
+ const int64 handle = numpy::PyIntOrPyLongToLong(py_int);
+ if (handle == -1 && PyErr_Occurred()) {
+ Py_DECREF(py_int);
+ Py_DECREF(o);
+ return NULL;
+ }
+ temps[i].set_handle(handle);
+ Py_DECREF(py_int);
+ Py_DECREF(o);
+ }
+ $1 = temps;
+}
+
+// Literal
+
+%typemap(in) const Literal& (std::unique_ptr<Literal> temp) {
+ temp = numpy::XlaLiteralFromPyObject($input);
+ $1 = &*temp;
+}
+
+%typemap(out) std::unique_ptr<Literal> {
+ $result = numpy::PyObjectFromXlaLiteral(*$1);
+}
+
+%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
+ if (!PySequence_Check($input)) {
+ PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
+ return NULL;
+ }
+ const int size = PySequence_Size($input);
+ for (int i = 0; i < size; ++i) {
+ PyObject* o = PySequence_GetItem($input, i);
+ temps.push_back(*numpy::XlaLiteralFromPyObject(o));
+ Py_DECREF(o);
+ }
+ $1 = &temps;
+}
+
+// Shape
+
+%typemap(in) const Shape& (Shape temp) {
+ if (!numpy::CheckPyShapeInfo($input)) {
+ return NULL;
+ }
+ temp = numpy::XlaShapeFromPyShapeInfo($input);
+ $1 = &temp;
+}
+
+%typemap(out) std::unique_ptr<Shape> {
+ $result = numpy::PyShapeInfoFromXlaShape(*$1);
+}
+
+%typemap(in) const std::vector<Shape>& (std::vector<Shape> temps) {
+ if (!PySequence_Check($input)) {
+ PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
+ return NULL;
+ }
+ const int size = PySequence_Size($input);
+ for (int i = 0; i < size; ++i) {
+ PyObject* o = PySequence_GetItem($input, i);
+ if (!numpy::CheckPyShapeInfo(o)) {
+ Py_DECREF(o);
+ return NULL;
+ }
+ temps.push_back(numpy::XlaShapeFromPyShapeInfo(o));
+ Py_DECREF(o);
+ }
+ $1 = &temps;
+}
+
+// PrimitiveType
+
+%typemap(in) PrimitiveType {
+ PyObject* py_int = numpy::PyNumberToPyInt($input);
+ if (!py_int) {
+ PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int");
+ return NULL;
+ }
+ const long value = numpy::PyIntOrPyLongToLong(py_int);
+ if (value == -1 && PyErr_Occurred()) {
+ Py_DECREF(py_int);
+ return NULL;
+ }
+ if (!PrimitiveType_IsValid(value)) {
+ PyErr_SetString(
+ PyExc_TypeError, "Argument not valid for PrimitiveType enum");
+ Py_DECREF(py_int);
+ return NULL;
+ }
+ $1 = static_cast<PrimitiveType>(value);
+}
+
+%ignoreall
+%unignore xla;
+%unignore xla::swig;
+%unignore xla::swig::CompiledLocalComputation;
+%unignore xla::swig::CompiledLocalComputation::Execute;
+%unignore xla::swig::LocalComputation;
+%unignore xla::swig::LocalComputation::Compile;
+%unignore xla::swig::LocalComputationBuilder;
+%unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder;
+%unignore xla::swig::LocalComputationBuilder::Build;
+%unignore xla::swig::LocalComputationBuilder::Parameter;
+%unignore xla::swig::LocalComputationBuilder::GetShape;
+%unignore xla::swig::LocalComputationBuilder::ConstantLiteral;
+%unignore xla::swig::LocalComputationBuilder::ConstantR0;
+%unignore xla::swig::LocalComputationBuilder::Broadcast;
+%unignore xla::swig::LocalComputationBuilder::Reshape;
+%unignore xla::swig::LocalComputationBuilder::Slice;
+%unignore xla::swig::LocalComputationBuilder::DynamicSlice;
+%unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice;
+%unignore xla::swig::LocalComputationBuilder::ConcatInDim;
+%unignore xla::swig::LocalComputationBuilder::Select;
+%unignore xla::swig::LocalComputationBuilder::Tuple;
+%unignore xla::swig::LocalComputationBuilder::GetTupleElement;
+%unignore xla::swig::LocalComputationBuilder::ConvertElementType;
+%unignore xla::swig::LocalComputationBuilder::Call;
+%unignore xla::swig::LocalComputationBuilder::Transpose;
+%unignore xla::swig::LocalComputationBuilder::Map;
+%unignore xla::swig::LocalComputationBuilder::Reduce;
+%unignore xla::swig::LocalComputationBuilder::While;
+%unignore xla::swig::LocalComputationBuilder::Eq;
+%unignore xla::swig::LocalComputationBuilder::Ne;
+%unignore xla::swig::LocalComputationBuilder::Ge;
+%unignore xla::swig::LocalComputationBuilder::Gt;
+%unignore xla::swig::LocalComputationBuilder::Lt;
+%unignore xla::swig::LocalComputationBuilder::Le;
+%unignore xla::swig::LocalComputationBuilder::Dot;
+%unignore xla::swig::LocalComputationBuilder::Add;
+%unignore xla::swig::LocalComputationBuilder::Sub;
+%unignore xla::swig::LocalComputationBuilder::Mul;
+%unignore xla::swig::LocalComputationBuilder::Div;
+%unignore xla::swig::LocalComputationBuilder::Rem;
+%unignore xla::swig::LocalComputationBuilder::Max;
+%unignore xla::swig::LocalComputationBuilder::Min;
+%unignore xla::swig::LocalComputationBuilder::And;
+%unignore xla::swig::LocalComputationBuilder::Or;
+%unignore xla::swig::LocalComputationBuilder::Not;
+%unignore xla::swig::LocalComputationBuilder::Abs;
+%unignore xla::swig::LocalComputationBuilder::Exp;
+%unignore xla::swig::LocalComputationBuilder::Floor;
+%unignore xla::swig::LocalComputationBuilder::Ceil;
+%unignore xla::swig::LocalComputationBuilder::Log;
+%unignore xla::swig::LocalComputationBuilder::Sign;
+%unignore xla::swig::LocalComputationBuilder::Cos;
+%unignore xla::swig::LocalComputationBuilder::Sin;
+%unignore xla::swig::LocalComputationBuilder::Tanh;
+%unignore xla::swig::LocalComputationBuilder::SqrtF32;
+%unignore xla::swig::LocalComputationBuilder::SquareF32;
+%unignore xla::swig::LocalComputationBuilder::Pow;
+%unignore xla::swig::LocalComputationBuilder::IsFinite;
+%unignore xla::swig::LocalComputationBuilder::ReciprocalF32;
+%unignore xla::swig::LocalComputationBuilder::Neg;
+%unignore xla::swig::LocalComputationBuilder::Sort;
+%unignore xla::swig::DeleteLocalComputation;
+%unignore xla::swig::DeleteCompiledLocalComputation;
+
+%include "tensorflow/compiler/xla/python/local_computation_builder.h"
+
+%unignoreall
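In plain Python terms, the shape-information pairs described by the grammar at the top of this file look as follows (an illustration with hypothetical values, using only numpy dtypes and tuples):

    import numpy as np

    # (dtype, DIMENSIONS): a 3x5x7 array of F32s
    array_shape = (np.dtype('float32'), (3, 5, 7))

    # (dtype('O'), TUPLE_SHAPES): a tuple of a 2x3 F32 and a 4x5 F64
    tuple_shape = (np.dtype('O'), ((np.dtype('float32'), (2, 3)),
                                   (np.dtype('float64'), (4, 5))))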
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc
new file mode 100644
index 0000000000..b30bdc3669
--- /dev/null
+++ b/tensorflow/compiler/xla/python/numpy_bridge.cc
@@ -0,0 +1,389 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/python/numpy_bridge.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+namespace swig {
+
+namespace numpy {
+
+int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) {
+ switch (primitive_type) {
+ case PRED:
+ return NPY_BOOL;
+ case S8:
+ return NPY_INT8;
+ case S16:
+ return NPY_INT16;
+ case S32:
+ return NPY_INT32;
+ case S64:
+ return NPY_INT64;
+ case U8:
+ return NPY_UINT8;
+ case U16:
+ return NPY_UINT16;
+ case U32:
+ return NPY_UINT32;
+ case U64:
+ return NPY_UINT64;
+ case F16:
+ return NPY_FLOAT16;
+ case F32:
+ return NPY_FLOAT32;
+ case F64:
+ return NPY_FLOAT64;
+ case TUPLE:
+ return NPY_OBJECT;
+ default:
+ LOG(FATAL) << "No Numpy type for XLA primitive type " << primitive_type;
+ }
+}
+
+PrimitiveType NumpyTypeToPrimitiveType(int np_type) {
+ switch (np_type) {
+ case NPY_BOOL:
+ return PRED;
+ case NPY_INT8:
+ return S8;
+ case NPY_INT16:
+ return S16;
+ case NPY_INT32:
+ return S32;
+ case NPY_INT64:
+ return S64;
+ case NPY_UINT8:
+ return U8;
+ case NPY_UINT16:
+ return U16;
+ case NPY_UINT32:
+ return U32;
+ case NPY_UINT64:
+ return U64;
+ case NPY_FLOAT16:
+ return F16;
+ case NPY_FLOAT32:
+ return F32;
+ case NPY_FLOAT64:
+ return F64;
+ case NPY_OBJECT:
+ return TUPLE;
+ default:
+ LOG(FATAL) << "No XLA primitive type for Numpy type " << np_type;
+ }
+}
+
+bool NumpyTypeIsValid(int np_type) {
+ switch (np_type) {
+ case NPY_BOOL:
+ case NPY_INT8:
+ case NPY_INT16:
+ case NPY_INT32:
+ case NPY_INT64:
+ case NPY_UINT8:
+ case NPY_UINT16:
+ case NPY_UINT32:
+ case NPY_UINT64:
+ case NPY_FLOAT16:
+ case NPY_FLOAT32:
+ case NPY_FLOAT64:
+ case NPY_OBJECT:
+ return true;
+ default:
+ return false;
+ }
+}
+
+PyObject* PyShapeInfoFromXlaShape(const Shape& shape) {
+ int np_typenum = PrimitiveTypeToNumpyType(shape.element_type());
+ PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum);
+
+ PyObject* dimensions;
+ if (ShapeUtil::IsTuple(shape)) {
+ int num_elements = ShapeUtil::TupleElementCount(shape);
+    dimensions = PyTuple_New(num_elements);
+ for (int i = 0; i < num_elements; ++i) {
+ PyTuple_SET_ITEM(
+ dimensions, i,
+ PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i)));
+ }
+ } else {
+ int rank = ShapeUtil::Rank(shape);
+ dimensions = PyTuple_New(rank);
+ for (int i = 0; i < rank; ++i) {
+ PyTuple_SET_ITEM(dimensions, i,
+ LongToPyIntOrPyLong(ShapeUtil::GetDimension(shape, i)));
+ }
+ }
+ return PyTuple_Pack(2, np_dtype, dimensions);
+}
+
+// Precondition: o->ob_type == &PyArrayDescr_Type
+static int NumpyTypenum(PyObject* o) {
+ return reinterpret_cast<PyArray_Descr*>(o)->type_num;
+}
+
+bool CheckPyShapeInfo(PyObject* o) {
+ // The object is a tuple (a pair)
+ if (!PyTuple_Check(o)) {
+ PyErr_SetString(PyExc_TypeError, "Shape record must be a tuple");
+ return false;
+ }
+ if (PyTuple_Size(o) != 2) {
+ PyErr_SetString(PyExc_ValueError, "Shape record tuple must be of length 2");
+ return false;
+ }
+
+ // It has a first element, which is a numpy dtype object
+ PyObject* first = PyTuple_GetItem(o, 0);
+ if (!first) {
+ return false;
+ }
+ if (first->ob_type != &PyArrayDescr_Type) {
+ PyErr_SetString(
+ PyExc_TypeError,
+ "Shape record does not have a numpy dtype as its first element");
+ return false;
+ }
+ const int np_type = NumpyTypenum(first);
+ if (!NumpyTypeIsValid(np_type)) {
+ PyErr_SetString(PyExc_ValueError,
+ "Shape record has an invalid integer dtype");
+ return false;
+ }
+
+ // It has a second element, which is a tuple, either of shape
+ // records or of Python ints
+ PyObject* second = PyTuple_GetItem(o, 1);
+ if (!second) {
+ return false;
+ }
+ if (!PyTuple_Check(second)) {
+ PyErr_SetString(PyExc_TypeError,
+ "Shape record does not have a tuple as its second element");
+ return false;
+ }
+ const int length = PyTuple_Size(second);
+ const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type);
+ for (int i = 0; i < length; i++) {
+ PyObject* dimension = PyTuple_GetItem(second, i);
+ if (element_type == TUPLE) {
+ if (!CheckPyShapeInfo(dimension)) {
+ return false;
+ }
+ } else if (!CheckPyIntOrLong(dimension)) {
+ PyErr_SetString(PyExc_TypeError,
+ "Non-tuple shape record has a non-integer dimension");
+ return false;
+ }
+ }
+
+ return true;
+}
+
+// Precondition: CheckPyShapeInfo(o)
+Shape XlaShapeFromPyShapeInfo(PyObject* o) {
+ const int np_type = NumpyTypenum(PyTuple_GetItem(o, 0));
+ const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type);
+ PyObject* py_dimensions = PyTuple_GetItem(o, 1);
+ const int length = PyTuple_Size(py_dimensions);
+ if (element_type == TUPLE) {
+ std::vector<Shape> subshapes;
+ subshapes.reserve(length);
+ for (int i = 0; i < length; i++) {
+ subshapes.push_back(
+ XlaShapeFromPyShapeInfo(PyTuple_GetItem(py_dimensions, i)));
+ }
+ return ShapeUtil::MakeTupleShape(subshapes);
+ } else {
+ std::vector<int64> dimensions(length);
+ for (int i = 0; i < length; i++) {
+ dimensions[i] = PyIntOrPyLongToLong(PyTuple_GetItem(py_dimensions, i));
+ if (dimensions[i] == -1) {
+ CHECK(!PyErr_Occurred());
+ }
+ }
+ return ShapeUtil::MakeShape(element_type, dimensions);
+ }
+}
+
+PyObject* PyObjectFromXlaLiteral(const Literal& literal) {
+ if (ShapeUtil::IsTuple(literal.shape())) {
+ const std::vector<Literal>& tuple_literals = literal.tuple_literals();
+ int num_elements = ShapeUtil::TupleElementCount(literal.shape());
+ PyObject* tuple = PyTuple_New(num_elements);
+ for (int i = 0; i < num_elements; i++) {
+ PyTuple_SET_ITEM(tuple, i, PyObjectFromXlaLiteral(tuple_literals[i]));
+ }
+ return tuple;
+ } else {
+ int rank = ShapeUtil::Rank(literal.shape());
+ std::vector<long> dimensions(rank); // NOLINT - PyArray requires a long*
+ for (int i = 0; i < rank; i++) {
+ dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i);
+ }
+ int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type());
+ PyObject* array =
+ PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0);
+ CopyLiteralToNumpyArray(np_type, literal,
+ reinterpret_cast<PyArrayObject*>(array));
+ return array;
+ }
+}
+
+std::unique_ptr<Literal> XlaLiteralFromPyObject(PyObject* o) {
+ if (PyTuple_Check(o)) {
+ int num_elements = PyTuple_Size(o);
+ std::vector<std::unique_ptr<Literal>> elements;
+ elements.reserve(num_elements);
+ for (int i = 0; i < num_elements; i++) {
+ PyObject* element = PyTuple_GetItem(o, i);
+ elements.push_back(XlaLiteralFromPyObject(element));
+ }
+ return Literal::MakeTupleOwned(std::move(elements));
+ } else if (PyArray_Check(o)) {
+ PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(o);
+ int rank = PyArray_NDIM(py_array);
+ std::vector<int64> dimensions(rank);
+ for (int i = 0; i < rank; i++) {
+ dimensions[i] = PyArray_DIM(py_array, i);
+ }
+ int np_type = PyArray_TYPE(py_array);
+ auto literal = Literal::CreateFromDimensions(
+ NumpyTypeToPrimitiveType(np_type), dimensions);
+ CopyNumpyArrayToLiteral(np_type, py_array, literal.get());
+ return literal;
+ } else {
+    LOG(FATAL) << "Object is neither a tuple nor a Numpy array; "
+               << "cannot convert to XLA literal";
+ }
+}
+
+void CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
+ Literal* literal) {
+ switch (np_type) {
+ case NPY_BOOL:
+ CopyNumpyArrayToLiteral<bool>(py_array, literal);
+ break;
+ case NPY_INT32:
+ CopyNumpyArrayToLiteral<int32>(py_array, literal);
+ break;
+ case NPY_INT64:
+ CopyNumpyArrayToLiteral<int64>(py_array, literal);
+ break;
+ case NPY_UINT8:
+ CopyNumpyArrayToLiteral<uint8>(py_array, literal);
+ break;
+ case NPY_UINT32:
+ CopyNumpyArrayToLiteral<uint32>(py_array, literal);
+ break;
+ case NPY_UINT64:
+ CopyNumpyArrayToLiteral<uint64>(py_array, literal);
+ break;
+ case NPY_FLOAT16:
+ CopyNumpyArrayToLiteral<half>(py_array, literal);
+ break;
+ case NPY_FLOAT32:
+ CopyNumpyArrayToLiteral<float>(py_array, literal);
+ break;
+ case NPY_FLOAT64:
+ CopyNumpyArrayToLiteral<double>(py_array, literal);
+ break;
+ default:
+      LOG(FATAL) << "No XLA literal container for Numpy type " << np_type;
+ }
+}
+
+void CopyLiteralToNumpyArray(int np_type, const Literal& literal,
+ PyArrayObject* py_array) {
+ switch (np_type) {
+ case NPY_BOOL:
+ CopyLiteralToNumpyArray<bool>(literal, py_array);
+ break;
+ case NPY_INT32:
+ CopyLiteralToNumpyArray<int32>(literal, py_array);
+ break;
+ case NPY_INT64:
+ CopyLiteralToNumpyArray<int64>(literal, py_array);
+ break;
+ case NPY_UINT8:
+ CopyLiteralToNumpyArray<uint8>(literal, py_array);
+ break;
+ case NPY_UINT32:
+ CopyLiteralToNumpyArray<uint32>(literal, py_array);
+ break;
+ case NPY_UINT64:
+ CopyLiteralToNumpyArray<uint64>(literal, py_array);
+ break;
+ case NPY_FLOAT16:
+ CopyLiteralToNumpyArray<half>(literal, py_array);
+ break;
+ case NPY_FLOAT32:
+ CopyLiteralToNumpyArray<float>(literal, py_array);
+ break;
+ case NPY_FLOAT64:
+ CopyLiteralToNumpyArray<double>(literal, py_array);
+ break;
+ default:
+      LOG(FATAL) << "No XLA literal container for Numpy type " << np_type;
+ }
+}
+
+PyObject* LongToPyIntOrPyLong(long x) { // NOLINT
+#if PY_MAJOR_VERSION < 3
+ return PyInt_FromLong(x);
+#else
+ return PyLong_FromLong(x);
+#endif
+}
+
+long PyIntOrPyLongToLong(PyObject* o) { // NOLINT
+#if PY_MAJOR_VERSION < 3
+ return PyInt_AsLong(o);
+#else
+ return PyLong_AsLong(o);
+#endif
+}
+
+bool CheckPyIntOrLong(PyObject* o) {
+#if PY_MAJOR_VERSION < 3
+ return PyInt_Check(o);
+#else
+ if (!PyLong_Check(o)) {
+ return false;
+ }
+ int overflow = 0;
+ PyLong_AsLongAndOverflow(o, &overflow);
+ return (overflow == 0);
+#endif
+}
+
+PyObject* PyNumberToPyInt(PyObject* o) {
+#if PY_MAJOR_VERSION < 3
+ return PyNumber_Int(o);
+#else
+ return PyNumber_Long(o);
+#endif
+}
+
+} // namespace numpy
+
+} // namespace swig
+
+} // namespace xla
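On the Python side, the values that XlaLiteralFromPyObject consumes and PyObjectFromXlaLiteral produces follow the form T = ndarray | (T, ...) described in local_computation_builder.i; for example, with hypothetical data:

    import numpy as np

    x = np.arange(6, dtype=np.float32).reshape(2, 3)
    array_literal = x                  # converts to a rank-2 F32 literal
    tuple_literal = (x, (x, x), (x,))  # converts to a tuple-shaped literal
                                       # of an array and two nested tuples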
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h
new file mode 100644
index 0000000000..4e6ecbb0e8
--- /dev/null
+++ b/tensorflow/compiler/xla/python/numpy_bridge.h
@@ -0,0 +1,123 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// These functions transform Python/Numpy data structures to XLA data
+// structures and vice versa, performing copies where
+// appropriate. Python tuples and Numpy ndarrays translate to XLA
+// tuples and XLA literals, respectively, and Numpy shape/dtype
+// information is translated to XLA shape information.
+
+#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_
+#define TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_
+
+#include <algorithm>
+#include <memory>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/python/lib/core/numpy.h"
+
+namespace xla {
+
+namespace swig {
+
+namespace numpy {
+
+// Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy
+// dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and
+// vice versa.
+int PrimitiveTypeToNumpyType(PrimitiveType primitive_type);
+PrimitiveType NumpyTypeToPrimitiveType(int np_type);
+
+// Determines whether an integer-encoded Numpy dtype is valid,
+// i.e. has a supported conversion to an XLA PrimitiveType.
+bool NumpyTypeIsValid(int np_type);
+
+// Converts XLA shape information into a Python pair of the form
+// (numpy dtype, dimensions). If the XLA shape represents a tuple,
+// then the numpy dtype is NPY_OBJECT ('O') and `dimensions` is a
+// Python tuple of shape-description pairs, created
+// recursively. Otherwise, `dimensions` is a Python tuple-of-integers
+// providing the array dimensions.
+//
+// The return value is a new reference.
+PyObject* PyShapeInfoFromXlaShape(const Shape& shape);
+
+// Returns the outcome of a best-effort check that the Python object
+// is a pair of the form (numpy dtype, dimensions), as produced by
+// PyShapeInfoFromXlaShape.
+bool CheckPyShapeInfo(PyObject* o);
+
+// Performs the inverse conversion to that of PyShapeInfoFromXlaShape.
+//
+// The return value is a new reference.
+Shape XlaShapeFromPyShapeInfo(PyObject* o);
+
+// Converts an XLA literal to a Python object, either a Numpy ndarray
+// or a nested Python tuple thereof.
+//
+// To avoid transferring ownership of the data buffers that underlie
+// PyArrays and XLA literals, this function makes deep copies of all
+// array data.
+//
+// The return value is a new reference.
+PyObject* PyObjectFromXlaLiteral(const Literal& literal);
+
+// Converts a Numpy ndarray or a nested Python tuple thereof to a
+// corresponding XLA literal.
+//
+// To avoid transferring ownership of the data buffers that underlie
+// PyArrays and XLA literals, this function makes deep copies of all
+// array data.
+std::unique_ptr<Literal> XlaLiteralFromPyObject(PyObject* o);
+
+// The following functions copy array data from the buffers underlying Numpy
+// ndarrays into those underlying XLA literals, and vice versa.
+
+void CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
+ Literal* literal);
+
+void CopyLiteralToNumpyArray(int np_type, const Literal& literal,
+ PyArrayObject* py_array);
+
+template <typename NativeT>
+void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) {
+ NativeT* source = static_cast<NativeT*>(PyArray_DATA(py_array));
+ auto dest = literal->GetMutableArraySlice<NativeT>();
+ std::copy(source, source + PyArray_SIZE(py_array), dest.data());
+}
+
+template <typename NativeT>
+void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) {
+ NativeT* dest = static_cast<NativeT*>(PyArray_DATA(py_array));
+ auto source = literal.GetArraySlice<NativeT>();
+ std::copy(source.begin(), source.end(), dest);
+}
+
+// Workarounds for Python 2 and 3 interop
+
+PyObject* LongToPyIntOrPyLong(long x); // NOLINT
+long PyIntOrPyLongToLong(PyObject* o); // NOLINT
+bool CheckPyIntOrLong(PyObject* o);
+PyObject* PyNumberToPyInt(PyObject* o);
+
+} // namespace numpy
+
+} // namespace swig
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_
diff --git a/tensorflow/compiler/xla/python/xla.i b/tensorflow/compiler/xla/python/xla.i
new file mode 100644
index 0000000000..1c4021a558
--- /dev/null
+++ b/tensorflow/compiler/xla/python/xla.i
@@ -0,0 +1,18 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/* XLA-wide SWIG wrapper */
+
+%include "tensorflow/compiler/xla/python/local_computation_builder.i"
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
new file mode 100644
index 0000000000..c75d54856d
--- /dev/null
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -0,0 +1,605 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""An in-process, local XLA client in Python, supporting AOT compilation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+
+import numpy as np
+
+from tensorflow.compiler.xla import xla_data_pb2
+from tensorflow.compiler.xla.python import pywrap_xla as c_api
+
+_UNARY_OPS = [
+ 'Not',
+ 'Abs',
+ 'Exp',
+ 'Floor',
+ 'Ceil',
+ 'Log',
+ 'Sign',
+ 'Cos',
+ 'Sin',
+ 'Tanh',
+ 'SqrtF32',
+ 'SquareF32',
+ 'IsFinite',
+ 'ReciprocalF32',
+ 'Neg',
+ 'Sort',
+]
+
+_BINARY_OPS = [
+ 'Eq',
+ 'Ne',
+ 'Ge',
+ 'Gt',
+ 'Lt',
+ 'Le',
+ 'Add',
+ 'Sub',
+ 'Mul',
+ 'Div',
+ 'Rem',
+ 'Max',
+ 'Min',
+ 'And',
+ 'Or',
+ 'Pow',
+]
+
+# Most functions are snake_case for consistency with other modules,
+# whereas method names of ComputationBuilder and LocalComputation are
+# CamelCase for consistency with XLA.
+# pylint: disable=invalid-name
+
+XLA_ELEMENT_TYPE_TO_DTYPE = {
+ xla_data_pb2.F32: np.dtype(np.float32),
+ xla_data_pb2.F64: np.dtype(np.float64),
+ xla_data_pb2.S32: np.dtype(np.int32),
+ xla_data_pb2.S64: np.dtype(np.int64),
+ xla_data_pb2.PRED: np.dtype(np.bool),
+ xla_data_pb2.TUPLE: np.dtype(np.object),
+}
+
+DTYPE_TO_XLA_ELEMENT_TYPE = {
+ str(v): k
+ for k, v in XLA_ELEMENT_TYPE_TO_DTYPE.items()
+}
+
+
+class Shape(object):
+ """XLA shape.
+
+ Represents an XLA shape by a corresponding Python/Numpy type and a
+ list of dimensions, which are themselves Shapes in case this one
+ represents an XLA tuple.
+ """
+
+ def __init__(self, np_dtype, dimensions):
+ self.np_dtype = np_dtype
+ self._dimensions = dimensions
+
+ def element_type(self):
+ return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.np_dtype)]
+
+ def is_tuple(self):
+ return self.element_type() == xla_data_pb2.TUPLE
+
+ def dimensions(self):
+ if self.is_tuple():
+ raise ValueError('Tuple shape has no dimensions')
+ return self._dimensions
+
+ def tuple_shapes(self):
+ if not self.is_tuple():
+ raise ValueError('Shape is not a tuple shape')
+ return self._dimensions
+
+ @staticmethod
+ def from_numpy(npval):
+
+ def convert(npval):
+ if isinstance(npval, tuple):
+ return Shape(np.dtype('O'), tuple(convert(elt) for elt in npval))
+ else:
+ return Shape(npval.dtype, np.shape(npval))
+
+ return convert(require_numpy_array_layout(npval))
+
+
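For instance, Shape.from_numpy infers shapes like so (a hedged sketch, assuming the import path declared in the BUILD rule above):

    import numpy as np
    from tensorflow.compiler.xla.python import xla_client

    s = xla_client.Shape.from_numpy(np.zeros((2, 3), dtype=np.float32))
    s.dimensions()    # ==> (2, 3)
    s.element_type()  # ==> xla_data_pb2.F32
    t = xla_client.Shape.from_numpy((np.float32(0), np.float64(0)))
    t.is_tuple()      # ==> True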
+def _wrap_shape(shape_info):
+ dtype, dims = shape_info
+ element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)]
+ if element_type == xla_data_pb2.TUPLE:
+ dims = [_wrap_shape(subshape_info) for subshape_info in dims]
+ return Shape(dtype, dims)
+
+
+def _unwrap_shape(shape):
+ if shape.is_tuple():
+ components = tuple(
+ _unwrap_shape(subshape) for subshape in shape.tuple_shapes())
+ else:
+ components = shape.dimensions()
+ return (shape.np_dtype, components)
+
+
+def _unwrap_shapes(shapes):
+ return [_unwrap_shape(shape) for shape in shapes]
+
+
+def _wrap_data_handle(handle):
+ cdh = xla_data_pb2.ComputationDataHandle()
+ cdh.handle = handle
+ return cdh
+
+
+def _unwrap_data_handle(handle_proto):
+ return handle_proto.handle
+
+
+def _unwrap_data_handles(handle_protos):
+ return [_unwrap_data_handle(cdh) for cdh in handle_protos]
+
+
+def require_numpy_array_layout(value):
+ if isinstance(value, tuple):
+ return tuple(require_numpy_array_layout(x) for x in value)
+ else:
+ return np.require(value, requirements=['C', 'A'])
+
+
+class LocalComputation(object):
+ """Python wrapper for a local XLA Computation.
+
+ A LocalComputation can be executed if it is compiled. Otherwise, it
+ can still be used as a Computation where required by the
+ ComputationBuilder methods.
+ """
+
+ def __init__(self, c_local_computation, is_compiled):
+ self.c_local_computation = c_local_computation
+ self.is_compiled = is_compiled
+
+ # Ensure a reference to C-based destructor for use in __del__.
+ if is_compiled:
+ self._delete = c_api.DeleteCompiledLocalComputation
+ else:
+ self._delete = c_api.DeleteLocalComputation
+
+ def Compile(self, argument_shapes=()):
+ if self.is_compiled:
+ raise ValueError('Attempt to compile a compiled local XLA computation.')
+ return LocalComputation(
+ self.c_local_computation.Compile(_unwrap_shapes(argument_shapes)),
+ is_compiled=True)
+
+ def CompileWithExampleArguments(self, arguments=()):
+ return self.Compile(
+ argument_shapes=[Shape.from_numpy(arg) for arg in arguments])
+
+ def Execute(self, arguments=()):
+ if not self.is_compiled:
+ raise ValueError('Cannot execute an uncompiled local XLA computation.')
+ arguments = tuple(map(require_numpy_array_layout, arguments))
+ return self.c_local_computation.Execute(arguments)
+
+ def __del__(self):
+ self._delete(self.c_local_computation)
+
+
+class ComputationBuilder(object):
+ """XLA computation builder.
+
+  Enqueues XLA ops in sequence to build a
+ LocalComputation, which in turn can be compiled into a
+ CompiledLocalComputation, which in turn can be locally executed.
+ """
+
+ # The methods of this class map 1-to-1 onto the XLA C++
+ # computation builder API. Therefore, there's no need to laboriously list
+ # arguments and return values for every method, especially where it's obvious.
+ #
+ # pylint: disable=g-doc-return-or-yield
+ # pylint: disable=g-doc-args
+
+ def __init__(self, name):
+ self._client = c_api.LocalComputationBuilder(name.encode('utf8'))
+ self._parameter_numbering = itertools.count()
+
+ def Build(self):
+ return LocalComputation(self._client.Build(), is_compiled=False)
+
+ def Constant(self, value):
+ """Enqueues a constant op onto the computation.
+
+ Args:
+ value: value for the constant, as a np.array with an explicit dtype set
+ to one of the supported types.
+
+ Returns:
+ A ComputationDataHandle message.
+ """
+ value = require_numpy_array_layout(value)
+ return _wrap_data_handle(self._client.ConstantLiteral(value))
+
+ def ConstantF32Scalar(self, value):
+ """Convenience method to enqueue a scalar F32 constant op.
+
+ Args:
+ value: a floating-point number.
+
+ Returns:
+ A ComputationDataHandle message.
+ """
+ return self.Constant(np.array(value, dtype=np.float32))
+
+ def ConstantF64Scalar(self, value):
+ """Convenience method to enqueue a scalar F32 constant op.
+
+ Args:
+ value: a floating-point number.
+
+ Returns:
+ A ComputationDataHandle message.
+ """
+ return self.Constant(np.array(value, dtype=np.float64))
+
+ def ConstantS32Scalar(self, value):
+ """Convenience method to enqueue a scalar S32 constant op.
+
+ Args:
+      value: an integer.
+
+ Returns:
+ A ComputationDataHandle message.
+ """
+ return self.Constant(np.array(value, dtype=np.int32))
+
+ def ConstantS64Scalar(self, value):
+ """Convenience method to enqueue a scalar S64 constant op.
+
+ Args:
+      value: an integer.
+
+ Returns:
+ A ComputationDataHandle message.
+ """
+ return self.Constant(np.array(value, dtype=np.int64))
+
+ def ConstantPredScalar(self, value):
+ """Convenience method to enqueue a scalar PRED constant op.
+
+ Args:
+ value: a boolean value.
+
+ Returns:
+ A ComputationDataHandle message.
+ """
+ return self.Constant(np.array(value, dtype=np.bool))
+
+ def ParameterWithShape(self, shape, name=None, parameter_num=None):
+ """Enqueues a Parameter op onto the computation, given a shape.
+
+ Args:
+ shape: the parameter's shape as a Shape object.
+ name: optional string name for the parameter.
+ parameter_num: parameter number in the computation function. If None,
+ the next linear parameter number is used. The default value capability
+ can be used for auto-numbering. If you're using auto-numbering for some
+ parameters, use it for *all* parameters to avoid clashes.
+
+ Returns:
+ A ComputationDataHandle message.
+ """
+ if name is None:
+ name = ''
+ if parameter_num is None:
+ parameter_num = next(self._parameter_numbering)
+
+ return _wrap_data_handle(
+ self._client.Parameter(
+ parameter_num, _unwrap_shape(shape), name.encode('utf8')))
+
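+  # Illustrative sketch of the numbering caveat above: mixing the two styles,
+  # e.g.
+  #   builder.ParameterWithShape(s0)                   # auto-numbered as 0
+  #   builder.ParameterWithShape(s1, parameter_num=0)  # also numbered 0
+  # assigns parameter number 0 twice (s0 and s1 are hypothetical shapes), so
+  # number all parameters one way or the other.
+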
+ def ParameterFromNumpy(self, value, name=None, parameter_num=None):
+ """Enqueues a Parameter op onto the computation.
+
+ Args:
+ value: a Numpy array, or a nested tuple thereof, from which the
+ shape is inferred.
+ name: as in ParameterWithShape.
+ parameter_num: as in ParameterWithShape.
+
+ Returns:
+ A ComputationDataHandle message.
+ """
+ return self.ParameterWithShape(
+ Shape.from_numpy(value), name=name, parameter_num=parameter_num)
+
+ def Broadcast(self, operand, sizes):
+ """Enqueues a broadcast operation onto the computation.
+
+ Args:
+ operand: the operand ComputationDataHandle to broadcast.
+ sizes: an iterable of broadcast sizes.
+
+ Returns:
+ A ComputationDataHandle representing the added broadcast op.
+ """
+ return _wrap_data_handle(
+ self._client.Broadcast(_unwrap_data_handle(operand), sizes))
+
+ def Concatenate(self, operands, dimension):
+ """Enqueues a concatenate operation onto the computation.
+
+ Args:
+ operands: the operands to concatenate.
+ dimension: the dimension in which to perform the concatenation.
+
+ Returns:
+ A ComputationDataHandle representing the added concatenate op.
+ """
+ return _wrap_data_handle(
+ self._client.ConcatInDim(_unwrap_data_handles(operands), dimension))
+
+ def ConvertElementType(self, operand, new_element_type):
+ """Enqueues an element type conversion operation onto the computation.
+
+ Args:
+ operand: the operand to convert.
+ new_element_type: the target primitive type.
+
+ Returns:
+ A ComputationDataHandle representing the added conversion op.
+ """
+ return _wrap_data_handle(
+ self._client.ConvertElementType(
+ _unwrap_data_handle(operand), new_element_type))
+
+ def GetShape(self, operand):
+ return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand)))
+
+ def GetComputationStats(self):
+ raise NotImplementedError()
+
+ def Reshape(self, operand, dimensions, new_sizes):
+ """Reshape op."""
+ return _wrap_data_handle(
+ self._client.Reshape(
+ _unwrap_data_handle(operand), dimensions, new_sizes))
+
+ def Trans(self, operand):
+ """Specialized matrix transpose op."""
+ return _wrap_data_handle(
+ self._client.Transpose(_unwrap_data_handle(operand), [1, 0]))
+
+ def Transpose(self, operand, permutation):
+ """Transpose op."""
+ return _wrap_data_handle(
+ self._client.Transpose(_unwrap_data_handle(operand), permutation))
+
+ def Select(self, pred, on_true, on_false):
+ """Element-wise selection op.
+
+ Constructs an output array from elements of two input arrays, based on the
+ values of a predicate array.
+ """
+ return _wrap_data_handle(
+ self._client.Select(
+ _unwrap_data_handle(pred),
+ _unwrap_data_handle(on_true),
+ _unwrap_data_handle(on_false)))
+
+ def Slice(self, operand, start_indices, limit_indices, strides=None):
+ """Enqueues a slice operation onto the computation.
+
+ Args:
+ operand: ComputationDataHandle for the N dimensional array to be sliced.
+ start_indices: iterable of N integers containing the starting indices of
+ the slice for each dimension.
+ limit_indices: iterable of N integers containing the ending indices
+ (exclusive) of the slice for each dimension.
+ strides: optional iterable of N integers containing the stride sizes for
+ each dimension.
+
+ Returns:
+ A ComputationDataHandle representing the added Slice op.
+ """
+ if strides is None:
+ start_indices = list(start_indices)
+ strides = [1] * len(start_indices)
+ return _wrap_data_handle(
+ self._client.Slice(
+ _unwrap_data_handle(operand),
+ start_indices,
+ limit_indices,
+ strides))
+
+ def DynamicSlice(self, operand, start_indices, slice_sizes):
+ """Enqueues a slice op with dynamic start indices onto the computation.
+
+ Args:
+ operand: ComputationDataHandle for the N dimensional array to be sliced.
+ start_indices: ComputationDataHandle for the 1D array of N integers
+ containing the starting indices of the slice.
+ slice_sizes: iterable of N integers containing the slice sizes in each
+ dimension.
+
+ Returns:
+ A ComputationDataHandle representing the added DynamicSlice op.
+ """
+ return _wrap_data_handle(
+ self._client.DynamicSlice(
+ _unwrap_data_handle(operand),
+ _unwrap_data_handle(start_indices),
+ slice_sizes))
+
+ def DynamicUpdateSlice(self, operand, update, start_indices):
+ """Enqueues a dynamic update slice operation onto the computation.
+
+ Args:
+ operand: ComputationDataHandle for the N dimensional array to be updated.
+ update: N dimensional array comprising the slice update.
+ start_indices: Rank-1 array of N integers comprising the starting indices
+ of the slice along each dimension.
+
+    Returns:
+ A ComputationDataHandle representing the added DynamicUpdateSlice op.
+ """
+ return _wrap_data_handle(
+ self._client.DynamicUpdateSlice(
+ _unwrap_data_handle(operand),
+ _unwrap_data_handle(update),
+ _unwrap_data_handle(start_indices)))
+
+ def Tuple(self, *ops):
+ """Enqueues a tuple operation onto the computation.
+
+ Args:
+ ops: a sequence of tuple operands (each a ComputationDataHandle).
+
+ Returns:
+ A ComputationDataHandle representing the added Tuple op.
+ """
+ return _wrap_data_handle(self._client.Tuple(_unwrap_data_handles(ops)))
+
+ def GetTupleElement(self, tup, index):
+ """Enqueues a 'get tuple element' operation onto the computation.
+
+ Args:
+ tup: the tuple operand (a ComputationDataHandle).
+ index: numeric index to select from the tuple.
+
+ Returns:
+ A ComputationDataHandle representing the added GetTupleElement op.
+ """
+ return _wrap_data_handle(
+ self._client.GetTupleElement(_unwrap_data_handle(tup), index))
+
+ def Call(self, computation_to_apply, operands):
+ """Enqueues a call operation onto the computation.
+
+ Args:
+ computation_to_apply: a Computation object.
+ operands: an iterable of ComputationDataHandle. The number and types of
+ operands must match the arity of computation_to_apply.
+
+ Returns:
+ A ComputationDataHandle representing the added call op.
+ """
+ return _wrap_data_handle(
+ self._client.Call(computation_to_apply.c_local_computation,
+ _unwrap_data_handles(operands)))
+
+ def Map(self, operands, computation_to_apply, dimensions, static_operands=()):
+ """Enqueues a map operation onto the computation.
+
+ Args:
+ operands: an iterable of ComputationDataHandle.
+ computation_to_apply: a Computation object.
+      dimensions: dimensions over which to map the function.
+ static_operands: auxiliary arguments passed to the applied computation.
+
+ Returns:
+ A ComputationDataHandle representing the added Map op.
+ """
+ return _wrap_data_handle(
+ self._client.Map(
+ _unwrap_data_handles(operands),
+ computation_to_apply.c_local_computation,
+ dimensions,
+ _unwrap_data_handles(static_operands)))
+
+ def Reduce(self, operand, init_value, computation_to_apply, dimensions):
+ """Enqueues a reduction operation onto the computation.
+
+ Args:
+ operand: reduction operand (ComputationDataHandle).
+ init_value: reduction initial value (ComputationDataHandle).
+ computation_to_apply: a Computation object - binary reduction function.
+ dimensions: sequence of dimensions (integers) to reduce on.
+
+ Returns:
+ A ComputationDataHandle representing the added Reduce op.
+ """
+ return _wrap_data_handle(
+ self._client.Reduce(
+ _unwrap_data_handle(operand),
+ _unwrap_data_handle(init_value),
+ computation_to_apply.c_local_computation,
+ dimensions))
+
+ def While(self, cond, body, init):
+ """Enqueues a While operation onto the computation.
+
+ Args:
+      cond: a Computation for the loop condition, which has type T -> PRED.
+      body: a Computation for the loop body, which has type T -> T.
+      init: a ComputationDataHandle for the initial parameter, which has type T.
+
+    Returns:
+      A ComputationDataHandle representing the added While op.
+ """
+ return _wrap_data_handle(
+ self._client.While(cond.c_local_computation,
+ body.c_local_computation,
+ _unwrap_data_handle(init)))
+
+ def Dot(self, lhs, rhs):
+ """Matrix multiplication between lhs and rhs."""
+ return _wrap_data_handle(
+ self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs)))
+
+
+def _forward_methods_to_local_builder():
+ """Forward remaining ComputationBuilder methods to the C API.
+
+ Set up methods, corresponding to unary and binary XLA operations,
+ whose calls are forwarded in a boilerplate manner to the underlying
+ LocalComputationBuilder C-extension API.
+ """
+
+ def forward_to_local_builder_with_handles(target_method, is_binop=False):
+ """Generate a forwarding method that wraps/unwraps data handles."""
+
+ def forward(self, *args, **kwargs):
+ unwrapped_args = [_unwrap_data_handle(arg) for arg in args]
+
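+      # Binary ops take an optional broadcast_dimensions argument, appended
+      # here from the keyword (default ()) when not already given positionally.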
+ if is_binop and len(unwrapped_args) < 3:
+ unwrapped_args.append(kwargs.get('broadcast_dimensions', ()))
+
+ return _wrap_data_handle(
+ target_method(
+ self._client, # pylint: disable=protected-access
+ *unwrapped_args))
+
+ return forward
+
+ for method_name in _UNARY_OPS:
+ forward = forward_to_local_builder_with_handles(
+ getattr(c_api.LocalComputationBuilder, method_name))
+ forward.__name__ = method_name
+ setattr(ComputationBuilder, method_name, forward)
+
+ for method_name in _BINARY_OPS:
+ forward = forward_to_local_builder_with_handles(
+ getattr(c_api.LocalComputationBuilder, method_name), is_binop=True)
+ forward.__name__ = method_name
+ setattr(ComputationBuilder, method_name, forward)
+
+
+_forward_methods_to_local_builder()
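+
+# Illustrative end-to-end sketch of the API above (assuming an XLA-enabled
+# build; the values are made up):
+#   builder = ComputationBuilder('axpy')
+#   a = builder.ConstantF32Scalar(2.0)
+#   x = builder.ParameterFromNumpy(np.array([1, 2], dtype=np.float32))
+#   y = builder.ParameterFromNumpy(np.array([10, 20], dtype=np.float32))
+#   builder.Add(builder.Mul(a, x), y)
+#   compiled = builder.Build().CompileWithExampleArguments(
+#       [np.array([1, 2], dtype=np.float32),
+#        np.array([10, 20], dtype=np.float32)])
+#   compiled.Execute([np.array([1, 2], dtype=np.float32),
+#                     np.array([10, 20], dtype=np.float32)])  # [12., 24.]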
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
new file mode 100644
index 0000000000..878cd83edc
--- /dev/null
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -0,0 +1,898 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the Python extension-based XLA client."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+import unittest
+
+import numpy as np
+
+from tensorflow.compiler.xla.python import xla_client
+
+
+class LocalComputationTest(unittest.TestCase):
+ """Base class for running an XLA Computation through the local client."""
+
+ def _NewComputation(self, name=None):
+ if name is None:
+ name = self.id()
+ return xla_client.ComputationBuilder(name)
+
+ def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected):
+ assert expected is not None
+ compiled_c = c.Build().CompileWithExampleArguments(arguments)
+ result = compiled_c.Execute(arguments)
+ # Numpy's comparison methods are a bit too lenient by treating inputs as
+ # "array-like", meaning that scalar 4 will be happily compared equal to
+    # [[4]]. We'd like to be stricter, so we assert shapes as well.
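+    # For instance, np.testing.assert_allclose(4, [[4]]) passes without error.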
+ self.assertEqual(np.asanyarray(result).shape, np.asanyarray(expected).shape)
+ assert_func(result, expected)
+
+ def _ExecuteAndCompareExact(self, c, arguments=(), expected=None):
+ self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected)
+
+ def _ExecuteAndCompareClose(self, c, arguments=(), expected=None):
+ self._ExecuteAndAssertWith(np.testing.assert_allclose, c, arguments,
+ expected)
+
+
+def NumpyArrayF32(*args, **kwargs):
+ """Convenience wrapper to create Numpy arrays with a np.float32 dtype."""
+ return np.array(*args, dtype=np.float32, **kwargs)
+
+
+def NumpyArrayF64(*args, **kwargs):
+ """Convenience wrapper to create Numpy arrays with a np.float64 dtype."""
+ return np.array(*args, dtype=np.float64, **kwargs)
+
+
+def NumpyArrayS32(*args, **kwargs):
+ """Convenience wrapper to create Numpy arrays with a np.int32 dtype."""
+ return np.array(*args, dtype=np.int32, **kwargs)
+
+
+def NumpyArrayS64(*args, **kwargs):
+ """Convenience wrapper to create Numpy arrays with a np.int64 dtype."""
+ return np.array(*args, dtype=np.int64, **kwargs)
+
+
+def NumpyArrayBool(*args, **kwargs):
+ """Convenience wrapper to create Numpy arrays with a np.bool dtype."""
+ return np.array(*args, dtype=np.bool, **kwargs)
+
+
+class ComputationsWithConstantsTest(LocalComputationTest):
+ """Tests focusing on Constant ops."""
+
+ def testConstantScalarSumF32(self):
+ c = self._NewComputation()
+ c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14))
+ self._ExecuteAndCompareClose(c, expected=4.25)
+
+ def testConstantScalarSumF64(self):
+ c = self._NewComputation()
+ c.Add(c.ConstantF64Scalar(1.11), c.ConstantF64Scalar(3.14))
+ self._ExecuteAndCompareClose(c, expected=4.25)
+
+ def testConstantScalarSumS32(self):
+ c = self._NewComputation()
+ c.Add(c.ConstantS32Scalar(1), c.ConstantS32Scalar(2))
+ self._ExecuteAndCompareClose(c, expected=3)
+
+ def testConstantScalarSumS64(self):
+ c = self._NewComputation()
+ c.Add(c.ConstantS64Scalar(1), c.ConstantS64Scalar(2))
+ self._ExecuteAndCompareClose(c, expected=3)
+
+ def testConstantVectorMulF32(self):
+ c = self._NewComputation()
+ c.Mul(
+ c.Constant(NumpyArrayF32([2.5, 3.3, -1.2, 0.7])),
+ c.Constant(NumpyArrayF32([-1.2, 2, -2, -3])))
+ self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1])
+
+ def testConstantVectorMulF64(self):
+ c = self._NewComputation()
+ c.Mul(
+ c.Constant(NumpyArrayF64([2.5, 3.3, -1.2, 0.7])),
+ c.Constant(NumpyArrayF64([-1.2, 2, -2, -3])))
+ self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1])
+
+ def testConstantVectorScalarDivF32(self):
+ c = self._NewComputation()
+ c.Div(
+ c.Constant(NumpyArrayF32([1.5, 2.5, 3.0, -10.8])),
+ c.ConstantF32Scalar(2.0))
+ self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4])
+
+ def testConstantVectorScalarDivF64(self):
+ c = self._NewComputation()
+ c.Div(
+ c.Constant(NumpyArrayF64([1.5, 2.5, 3.0, -10.8])),
+ c.ConstantF64Scalar(2.0))
+ self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4])
+
+ def testConstantVectorScalarPowF32(self):
+ c = self._NewComputation()
+ c.Pow(c.Constant(NumpyArrayF32([1.5, 2.5, 3.0])), c.ConstantF32Scalar(2.))
+ self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.])
+
+ def testConstantVectorScalarPowF64(self):
+ c = self._NewComputation()
+ c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.))
+ self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.])
+
+ def testBooleanAnd(self):
+ c = self._NewComputation()
+ c.And(
+ c.Constant(NumpyArrayBool([True, False, True, False])),
+ c.Constant(NumpyArrayBool([True, True, False, False])))
+ self._ExecuteAndCompareExact(c, expected=[True, False, False, False])
+
+ def testBooleanOr(self):
+ c = self._NewComputation()
+ c.Or(
+ c.Constant(NumpyArrayBool([True, False, True, False])),
+ c.Constant(NumpyArrayBool([True, True, False, False])))
+ self._ExecuteAndCompareExact(c, expected=[True, True, True, False])
+
+ def testSum2DF32(self):
+ c = self._NewComputation()
+ c.Add(
+ c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])),
+ c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]])))
+ self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]])
+
+ def testSum2DF64(self):
+ c = self._NewComputation()
+ c.Add(
+ c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6]])),
+ c.Constant(NumpyArrayF64([[1, -1, 1], [-1, 1, -1]])))
+ self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]])
+
+ def testSum2DWith1DBroadcastDim0F32(self):
+ # sum of a 2D array with a 1D array where the latter is replicated across
+ # dimension 0 to match the former's shape.
+ c = self._NewComputation()
+ c.Add(
+ c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
+ c.Constant(NumpyArrayF32([10, 20, 30])),
+ broadcast_dimensions=(0,))
+ self._ExecuteAndCompareClose(
+ c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]])
+
+ def testSum2DWith1DBroadcastDim0F64(self):
+ # sum of a 2D array with a 1D array where the latter is replicated across
+ # dimension 0 to match the former's shape.
+ c = self._NewComputation()
+ c.Add(
+ c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
+ c.Constant(NumpyArrayF64([10, 20, 30])),
+ broadcast_dimensions=(0,))
+ self._ExecuteAndCompareClose(
+ c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]])
+
+ def testSum2DWith1DBroadcastDim1F32(self):
+ # sum of a 2D array with a 1D array where the latter is replicated across
+ # dimension 1 to match the former's shape.
+ c = self._NewComputation()
+ c.Add(
+ c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
+ c.Constant(NumpyArrayF32([10, 20, 30])),
+ broadcast_dimensions=(1,))
+ self._ExecuteAndCompareClose(
+ c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]])
+
+ def testSum2DWith1DBroadcastDim1F64(self):
+ # sum of a 2D array with a 1D array where the latter is replicated across
+ # dimension 1 to match the former's shape.
+ c = self._NewComputation()
+ c.Add(
+ c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
+ c.Constant(NumpyArrayF64([10, 20, 30])),
+ broadcast_dimensions=(1,))
+ self._ExecuteAndCompareClose(
+ c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]])
+
+ def testConstantAxpyF32(self):
+ c = self._NewComputation()
+ c.Add(
+ c.Mul(
+ c.ConstantF32Scalar(2),
+ c.Constant(NumpyArrayF32([2.2, 3.3, 4.4, 5.5]))),
+ c.Constant(NumpyArrayF32([100, -100, 200, -200])))
+ self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189])
+
+ def testConstantAxpyF64(self):
+ c = self._NewComputation()
+ c.Add(
+ c.Mul(
+ c.ConstantF64Scalar(2),
+ c.Constant(NumpyArrayF64([2.2, 3.3, 4.4, 5.5]))),
+ c.Constant(NumpyArrayF64([100, -100, 200, -200])))
+ self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189])
+
+
+class ParametersTest(LocalComputationTest):
+ """Tests focusing on Parameter ops and argument-passing."""
+
+ def setUp(self):
+ self.f32_scalar_2 = NumpyArrayF32(2.0)
+ self.f32_4vector = NumpyArrayF32([-2.3, 3.3, -4.3, 5.3])
+ self.f64_scalar_2 = NumpyArrayF64(2.0)
+ self.f64_4vector = NumpyArrayF64([-2.3, 3.3, -4.3, 5.3])
+ self.s32_scalar_3 = NumpyArrayS32(3)
+ self.s32_4vector = NumpyArrayS32([10, 15, -2, 7])
+ self.s64_scalar_3 = NumpyArrayS64(3)
+ self.s64_4vector = NumpyArrayS64([10, 15, -2, 7])
+
+ def testScalarTimesVectorAutonumberF32(self):
+ c = self._NewComputation()
+ p0 = c.ParameterFromNumpy(self.f32_scalar_2)
+ p1 = c.ParameterFromNumpy(self.f32_4vector)
+ c.Mul(p0, p1)
+ self._ExecuteAndCompareClose(
+ c,
+ arguments=[self.f32_scalar_2, self.f32_4vector],
+ expected=[-4.6, 6.6, -8.6, 10.6])
+
+ def testScalarTimesVectorAutonumberF64(self):
+ c = self._NewComputation()
+ p0 = c.ParameterFromNumpy(self.f64_scalar_2)
+ p1 = c.ParameterFromNumpy(self.f64_4vector)
+ c.Mul(p0, p1)
+ self._ExecuteAndCompareClose(
+ c,
+ arguments=[self.f64_scalar_2, self.f64_4vector],
+ expected=[-4.6, 6.6, -8.6, 10.6])
+
+ def testScalarTimesVectorS32(self):
+ c = self._NewComputation()
+ p0 = c.ParameterFromNumpy(self.s32_scalar_3)
+ p1 = c.ParameterFromNumpy(self.s32_4vector)
+ c.Mul(p0, p1)
+ self._ExecuteAndCompareExact(
+ c,
+ arguments=[self.s32_scalar_3, self.s32_4vector],
+ expected=[30, 45, -6, 21])
+
+ def testScalarTimesVectorS64(self):
+ c = self._NewComputation()
+ p0 = c.ParameterFromNumpy(self.s64_scalar_3)
+ p1 = c.ParameterFromNumpy(self.s64_4vector)
+ c.Mul(p0, p1)
+ self._ExecuteAndCompareExact(
+ c,
+ arguments=[self.s64_scalar_3, self.s64_4vector],
+ expected=[30, 45, -6, 21])
+
+ def testScalarMinusVectorExplicitNumberingF32(self):
+ # Use explicit numbering and pass parameter_num first. Sub is used since
+ # it's not commutative and can help catch parameter reversal within the
+ # computation.
+ c = self._NewComputation()
+ p1 = c.ParameterFromNumpy(self.f32_4vector, parameter_num=1)
+ p0 = c.ParameterFromNumpy(self.f32_scalar_2, parameter_num=0)
+ c.Sub(p1, p0)
+ self._ExecuteAndCompareClose(
+ c,
+ arguments=[self.f32_scalar_2, self.f32_4vector],
+ expected=[-4.3, 1.3, -6.3, 3.3])
+
+ def testScalarMinusVectorExplicitNumberingF64(self):
+ # Use explicit numbering and pass parameter_num first. Sub is used since
+ # it's not commutative and can help catch parameter reversal within the
+ # computation.
+ c = self._NewComputation()
+ p1 = c.ParameterFromNumpy(self.f64_4vector, parameter_num=1)
+ p0 = c.ParameterFromNumpy(self.f64_scalar_2, parameter_num=0)
+ c.Sub(p1, p0)
+ self._ExecuteAndCompareClose(
+ c,
+ arguments=[self.f64_scalar_2, self.f64_4vector],
+ expected=[-4.3, 1.3, -6.3, 3.3])
+
+
+class SingleOpTest(LocalComputationTest):
+ """Tests for single ops.
+
+ The goal here is smoke testing - to exercise the most basic functionality of
+  single XLA ops. As few additional ops as possible are added around the op
+  being tested.
+ """
+
+ def testConcatenateF32(self):
+ c = self._NewComputation()
+ c.Concatenate(
+ (c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])),
+ c.Constant(NumpyArrayF32([4.0, 5.0, 6.0]))),
+ dimension=0)
+ self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+
+ def testConcatenateF64(self):
+ c = self._NewComputation()
+ c.Concatenate(
+ (c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])),
+ c.Constant(NumpyArrayF64([4.0, 5.0, 6.0]))),
+ dimension=0)
+ self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+
+ def testConvertElementType(self):
+ xla_types = {
+ np.bool: xla_client.xla_data_pb2.PRED,
+ np.int32: xla_client.xla_data_pb2.S32,
+ np.int64: xla_client.xla_data_pb2.S64,
+ np.float32: xla_client.xla_data_pb2.F32,
+ np.float64: xla_client.xla_data_pb2.F64,
+ }
+
+ def _ConvertAndTest(template, src_dtype, dst_dtype):
+ c = self._NewComputation()
+ x = c.Constant(np.array(template, dtype=src_dtype))
+ c.ConvertElementType(x, xla_types[dst_dtype])
+
+ result = c.Build().Compile().Execute()
+ expected = np.array(template, dtype=dst_dtype)
+
+ self.assertEqual(result.shape, expected.shape)
+ self.assertEqual(result.dtype, expected.dtype)
+ np.testing.assert_equal(result, expected)
+
+ x = [0, 1, 0, 0, 1]
+ for src_dtype, dst_dtype in itertools.product(xla_types, xla_types):
+ _ConvertAndTest(x, src_dtype, dst_dtype)
+
+ def testDotMatrixVectorF32(self):
+ c = self._NewComputation()
+ lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]])
+ rhs = NumpyArrayF32([[10.0], [20.0]])
+ c.Dot(c.Constant(lhs), c.Constant(rhs))
+ self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
+
+ def testDotMatrixVectorF64(self):
+ c = self._NewComputation()
+ lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]])
+ rhs = NumpyArrayF64([[10.0], [20.0]])
+ c.Dot(c.Constant(lhs), c.Constant(rhs))
+ self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
+
+ def testDotMatrixMatrixF32(self):
+ c = self._NewComputation()
+ lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]])
+ rhs = NumpyArrayF32([[10.0, 20.0], [100.0, 200.0]])
+ c.Dot(c.Constant(lhs), c.Constant(rhs))
+ self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
+
+ def testDotMatrixMatrixF64(self):
+ c = self._NewComputation()
+ lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]])
+ rhs = NumpyArrayF64([[10.0, 20.0], [100.0, 200.0]])
+ c.Dot(c.Constant(lhs), c.Constant(rhs))
+ self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
+
+ def testBooleanNot(self):
+ c = self._NewComputation()
+ arr = NumpyArrayBool([True, False, True])
+ c.Not(c.Constant(arr))
+ self._ExecuteAndCompareClose(c, expected=~arr)
+
+ def testExp(self):
+ c = self._NewComputation()
+ arr = NumpyArrayF32([3.3, 12.1])
+ c.Exp(c.Constant(arr))
+ self._ExecuteAndCompareClose(c, expected=np.exp(arr))
+
+ def testLog(self):
+ c = self._NewComputation()
+ arr = NumpyArrayF32([3.3, 12.1])
+ c.Log(c.Constant(arr))
+ self._ExecuteAndCompareClose(c, expected=np.log(arr))
+
+ def testNeg(self):
+ c = self._NewComputation()
+ arr = NumpyArrayF32([3.3, 12.1])
+ c.Neg(c.Constant(arr))
+ self._ExecuteAndCompareClose(c, expected=-arr)
+
+ def testFloor(self):
+ c = self._NewComputation()
+ arr = NumpyArrayF32([3.3, 12.1])
+ c.Floor(c.Constant(arr))
+ self._ExecuteAndCompareClose(c, expected=np.floor(arr))
+
+ def testCeil(self):
+ c = self._NewComputation()
+ arr = NumpyArrayF32([3.3, 12.1])
+ c.Ceil(c.Constant(arr))
+ self._ExecuteAndCompareClose(c, expected=np.ceil(arr))
+
+ def testAbs(self):
+ c = self._NewComputation()
+ arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.])
+ c.Abs(c.Constant(arr))
+ self._ExecuteAndCompareClose(c, expected=np.abs(arr))
+
+ def testTanh(self):
+ c = self._NewComputation()
+ arr = NumpyArrayF32([3.3, 12.1])
+ c.Tanh(c.Constant(arr))
+ self._ExecuteAndCompareClose(c, expected=np.tanh(arr))
+
+ def testTrans(self):
+
+ def _TransposeAndTest(array):
+ c = self._NewComputation()
+ c.Trans(c.Constant(array))
+ self._ExecuteAndCompareClose(c, expected=array.T)
+
+ # Test square and non-square matrices in both default (C) and F orders.
+ for array_fun in [NumpyArrayF32, NumpyArrayF64]:
+ _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]]))
+ _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]], order="F"))
+ _TransposeAndTest(array_fun([[1, 2], [4, 5]]))
+ _TransposeAndTest(array_fun([[1, 2], [4, 5]], order="F"))
+
+ def testTranspose(self):
+
+ def _TransposeAndTest(array, permutation):
+ c = self._NewComputation()
+ c.Transpose(c.Constant(array), permutation)
+ expected = np.transpose(array, permutation)
+ self._ExecuteAndCompareClose(c, expected=expected)
+
+ _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1])
+ _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0])
+ _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1])
+ _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0])
+
+ arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32)
+ for permutation in itertools.permutations(range(arr.ndim)):
+ _TransposeAndTest(arr, permutation)
+ _TransposeAndTest(np.asfortranarray(arr), permutation)
+
+ def testEq(self):
+ c = self._NewComputation()
+ c.Eq(
+ c.Constant(NumpyArrayS32([1, 2, 3, 4])),
+ c.Constant(NumpyArrayS32([4, 2, 3, 1])))
+ self._ExecuteAndCompareExact(c, expected=[False, True, True, False])
+
+ def testNe(self):
+ c = self._NewComputation()
+ c.Ne(
+ c.Constant(NumpyArrayS32([1, 2, 3, 4])),
+ c.Constant(NumpyArrayS32([4, 2, 3, 1])))
+ self._ExecuteAndCompareExact(c, expected=[True, False, False, True])
+
+ c.Ne(
+ c.Constant(NumpyArrayF32([-2.0, 0.0,
+ float("nan"),
+ float("nan")])),
+ c.Constant(NumpyArrayF32([2.0, -0.0, 1.0, float("nan")])))
+ self._ExecuteAndAssertWith(
+ np.testing.assert_allclose, c, (), expected=[True, False, True, True])
+
+ def testGt(self):
+ c = self._NewComputation()
+ c.Gt(
+ c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
+ c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
+ self._ExecuteAndCompareExact(c, expected=[False, True, True, False, False])
+
+ def testGe(self):
+ c = self._NewComputation()
+ c.Ge(
+ c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
+ c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
+ self._ExecuteAndCompareExact(c, expected=[True, True, True, False, False])
+
+ def testLt(self):
+ c = self._NewComputation()
+ c.Lt(
+ c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
+ c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
+ self._ExecuteAndCompareExact(c, expected=[False, False, False, True, True])
+
+ def testLe(self):
+ c = self._NewComputation()
+ c.Le(
+ c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
+ c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
+ self._ExecuteAndCompareExact(c, expected=[True, False, False, True, True])
+
+ def testMax(self):
+ c = self._NewComputation()
+ c.Max(
+ c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
+ c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
+ self._ExecuteAndCompareExact(c, expected=[1.0, 2.0, 3.0, 7.0, 12.0])
+
+ def testMaxExplicitBroadcastDim0(self):
+ c = self._NewComputation()
+ c.Max(
+ c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
+ c.Constant(NumpyArrayF32([3, 4, 5])),
+ broadcast_dimensions=(0,))
+ self._ExecuteAndCompareExact(c, expected=[[3, 3, 3], [4, 5, 6], [7, 8, 9]])
+
+ def testMaxExplicitBroadcastDim1(self):
+ c = self._NewComputation()
+ c.Max(
+ c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
+ c.Constant(NumpyArrayF32([3, 4, 5])),
+ broadcast_dimensions=(1,))
+ self._ExecuteAndCompareExact(c, expected=[[3, 4, 5], [4, 5, 6], [7, 8, 9]])
+
+ def testMin(self):
+ c = self._NewComputation()
+ c.Min(
+ c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
+ c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
+ self._ExecuteAndCompareExact(c, expected=[1.0, 0.0, 2.0, 4.0, 9.0])
+
+ def testReshape(self):
+ c = self._NewComputation()
+ c.Reshape(
+ c.Constant(NumpyArrayS32([[1, 2], [3, 4], [5, 6]])),
+ dimensions=[0, 1],
+ new_sizes=[2, 3])
+ self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 5, 6]])
+
+ def testSelect(self):
+ c = self._NewComputation()
+ c.Select(
+ c.Constant(NumpyArrayBool([True, False, False, True, False])),
+ c.Constant(NumpyArrayS32([1, 2, 3, 4, 5])),
+ c.Constant(NumpyArrayS32([-1, -2, -3, -4, -5])))
+ self._ExecuteAndCompareExact(c, expected=[1, -2, -3, 4, -5])
+
+ def testSlice(self):
+ c = self._NewComputation()
+ c.Slice(
+ c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [1, 0],
+ [3, 2])
+ self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]])
+
+ def testDynamicSlice(self):
+ c = self._NewComputation()
+ c.DynamicSlice(
+ c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
+ c.Constant(NumpyArrayS32([1, 0])), [2, 2])
+ self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]])
+
+ def testDynamicUpdateSlice(self):
+ c = self._NewComputation()
+ c.DynamicUpdateSlice(
+ c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
+ c.Constant(NumpyArrayS32([[1, 2], [3, 4]])),
+ c.Constant(NumpyArrayS32([1, 1])))
+ self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 1, 2], [7, 3, 4]])
+
+ def testTuple(self):
+ c = self._NewComputation()
+ c.Tuple(
+ c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])),
+ c.Constant(NumpyArrayBool([True, False, False, True])))
+ result = c.Build().Compile().Execute()
+ self.assertIsInstance(result, tuple)
+ np.testing.assert_equal(result[0], 42)
+ np.testing.assert_allclose(result[1], [1.0, 2.0])
+ np.testing.assert_equal(result[2], [True, False, False, True])
+
+ def testGetTupleElement(self):
+ c = self._NewComputation()
+ c.GetTupleElement(
+ c.Tuple(
+ c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])),
+ c.Constant(NumpyArrayBool([True, False, False, True]))), 1)
+ self._ExecuteAndCompareClose(c, expected=[1.0, 2.0])
+
+ def testBroadcast(self):
+ c = self._NewComputation()
+ c.Broadcast(c.Constant(NumpyArrayS32([10, 20, 30, 40])), sizes=(3,))
+ self._ExecuteAndCompareExact(
+ c, expected=[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]])
+
+
+class EmbeddedComputationsTest(LocalComputationTest):
+ """Tests for XLA graphs with embedded computations (such as maps)."""
+
+ def _CreateConstantS32Computation(self):
+ """Computation (f32) -> s32 that returns a constant 1 for any input."""
+ c = self._NewComputation("constant_s32_one")
+ # TODO(eliben): consider adding a nicer way to create new parameters without
+ # having to create dummy Numpy arrays or populating Shape messages. Perhaps
+ # we need our own (Python-client-own) way to represent Shapes conveniently.
+ c.ParameterFromNumpy(NumpyArrayF32(0))
+ c.ConstantS32Scalar(1)
+ return c.Build()
+
+ def _CreateConstantS64Computation(self):
+ """Computation (f64) -> s64 that returns a constant 1 for any input."""
+ c = self._NewComputation("constant_s64_one")
+ # TODO(eliben): consider adding a nicer way to create new parameters without
+ # having to create dummy Numpy arrays or populating Shape messages. Perhaps
+ # we need our own (Python-client-own) way to represent Shapes conveniently.
+ c.ParameterFromNumpy(NumpyArrayF64(0))
+ c.ConstantS64Scalar(1)
+ return c.Build()
+
+ def _CreateConstantF32Computation(self):
+ """Computation (f32) -> f32 that returns a constant 1.0 for any input."""
+ c = self._NewComputation("constant_f32_one")
+ c.ParameterFromNumpy(NumpyArrayF32(0))
+ c.ConstantF32Scalar(1.0)
+ return c.Build()
+
+ def _CreateConstantF64Computation(self):
+ """Computation (f64) -> f64 that returns a constant 1.0 for any input."""
+ c = self._NewComputation("constant_f64_one")
+ c.ParameterFromNumpy(NumpyArrayF64(0))
+ c.ConstantF64Scalar(1.0)
+ return c.Build()
+
+ def _CreateMulF32By2Computation(self):
+ """Computation (f32) -> f32 that multiplies its parameter by 2."""
+ c = self._NewComputation("mul_f32_by2")
+ c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(2.0))
+ return c.Build()
+
+ def _CreateMulF64By2Computation(self):
+ """Computation (f64) -> f64 that multiplies its parameter by 2."""
+ c = self._NewComputation("mul_f64_by2")
+ c.Mul(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(2.0))
+ return c.Build()
+
+ def _CreateBinaryAddF32Computation(self):
+ """Computation (f32, f32) -> f32 that adds its two parameters."""
+ c = self._NewComputation("add_param0_by_param1")
+ c.Add(
+ c.ParameterFromNumpy(NumpyArrayF32(0)),
+ c.ParameterFromNumpy(NumpyArrayF32(0)))
+ return c.Build()
+
+ def _CreateBinaryAddF64Computation(self):
+ """Computation (f64, f64) -> f64 that adds its two parameters."""
+ c = self._NewComputation("add_param0_by_param1")
+ c.Add(
+ c.ParameterFromNumpy(NumpyArrayF64(0)),
+ c.ParameterFromNumpy(NumpyArrayF64(0)))
+ return c.Build()
+
+ def _CreateBinaryDivF32Computation(self):
+ """Computation (f32, f32) -> f32 that divides its two parameters."""
+ c = self._NewComputation("div_param0_by_param1")
+ c.Div(
+ c.ParameterFromNumpy(NumpyArrayF32(0)),
+ c.ParameterFromNumpy(NumpyArrayF32(0)))
+ return c.Build()
+
+ def _CreateBinaryDivF64Computation(self):
+ """Computation (f64, f64) -> f64 that divides its two parameters."""
+ c = self._NewComputation("div_param0_by_param1")
+ c.Div(
+ c.ParameterFromNumpy(NumpyArrayF64(0)),
+ c.ParameterFromNumpy(NumpyArrayF64(0)))
+ return c.Build()
+
+ def _CreateTestF32Lt10Computation(self):
+ """Computation (f32) -> bool that tests if its parameter is less than 10."""
+ c = self._NewComputation("test_f32_lt_10")
+ c.Lt(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(10.))
+ return c.Build()
+
+ def _CreateTestF64Lt10Computation(self):
+ """Computation (f64) -> bool that tests if its parameter is less than 10."""
+ c = self._NewComputation("test_f64_lt_10")
+ c.Lt(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(10.))
+ return c.Build()
+
+ def _MakeSample3DArrayF32(self):
+ return NumpyArrayF32([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
+ [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])
+
+ def _MakeSample3DArrayF64(self):
+ return NumpyArrayF64([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
+ [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])
+
+ def testCallF32(self):
+ c = self._NewComputation()
+ c.Call(
+ self._CreateMulF32By2Computation(),
+ operands=(c.ConstantF32Scalar(5.0),))
+ self._ExecuteAndCompareClose(c, expected=10.0)
+
+ def testCallF64(self):
+ c = self._NewComputation()
+ c.Call(
+ self._CreateMulF64By2Computation(),
+ operands=(c.ConstantF64Scalar(5.0),))
+ self._ExecuteAndCompareClose(c, expected=10.0)
+
+ def testMapEachElementToS32Constant(self):
+ c = self._NewComputation()
+ c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
+ self._CreateConstantS32Computation(), [0])
+ self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1])
+
+ def testMapEachElementToS64Constant(self):
+ c = self._NewComputation()
+ c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
+ self._CreateConstantS64Computation(), [0])
+ self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1])
+
+ def testMapMulBy2F32(self):
+ c = self._NewComputation()
+ c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
+ self._CreateMulF32By2Computation(), [0])
+ self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0])
+
+ def testMapMulBy2F64(self):
+ c = self._NewComputation()
+ c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
+ self._CreateMulF64By2Computation(), [0])
+ self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0])
+
+ def testSimpleMapChainF32(self):
+ # Chains a map of constant-f32 with a map of mul-by-2
+ c = self._NewComputation()
+ const_f32 = c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
+ self._CreateConstantF32Computation(), [0])
+ c.Map([const_f32], self._CreateMulF32By2Computation(), [0])
+ self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0])
+
+ def testSimpleMapChainF64(self):
+ # Chains a map of constant-f64 with a map of mul-by-2
+ c = self._NewComputation()
+ const_f64 = c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
+ self._CreateConstantF64Computation(), [0])
+ c.Map([const_f64], self._CreateMulF64By2Computation(), [0])
+ self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0])
+
+ def testDivVectorsWithMapF32(self):
+ c = self._NewComputation()
+ c.Map((c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])),
+ c.Constant(NumpyArrayF32([5.0, 5.0, 4.0, 4.0]))),
+ self._CreateBinaryDivF32Computation(), [0])
+ self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0])
+
+ def testDivVectorsWithMapF64(self):
+ c = self._NewComputation()
+ c.Map((c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])),
+ c.Constant(NumpyArrayF64([5.0, 5.0, 4.0, 4.0]))),
+ self._CreateBinaryDivF64Computation(), [0])
+ self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0])
+
+ def testReduce1DtoScalarF32(self):
+ c = self._NewComputation()
+ c.Reduce(
+ operand=c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])),
+ init_value=c.ConstantF32Scalar(0),
+ computation_to_apply=self._CreateBinaryAddF32Computation(),
+ dimensions=[0])
+ self._ExecuteAndCompareClose(c, expected=10)
+
+ def testReduce1DtoScalarF64(self):
+ c = self._NewComputation()
+ c.Reduce(
+ operand=c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])),
+ init_value=c.ConstantF64Scalar(0),
+ computation_to_apply=self._CreateBinaryAddF64Computation(),
+ dimensions=[0])
+ self._ExecuteAndCompareClose(c, expected=10)
+
+ def testReduce2DTo1DDim0F32(self):
+ input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+ c = self._NewComputation()
+ c.Reduce(
+ operand=c.Constant(input_array),
+ init_value=c.ConstantF32Scalar(0),
+ computation_to_apply=self._CreateBinaryAddF32Computation(),
+ dimensions=[0])
+ self._ExecuteAndCompareClose(c, expected=[5, 7, 9])
+
+ def testReduce2DTo1DDim0F64(self):
+ input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+ c = self._NewComputation()
+ c.Reduce(
+ operand=c.Constant(input_array),
+ init_value=c.ConstantF64Scalar(0),
+ computation_to_apply=self._CreateBinaryAddF64Computation(),
+ dimensions=[0])
+ self._ExecuteAndCompareClose(c, expected=[5, 7, 9])
+
+ def testReduce2DTo1DDim1F32(self):
+ input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+ c = self._NewComputation()
+ c.Reduce(
+ operand=c.Constant(input_array),
+ init_value=c.ConstantF32Scalar(0),
+ computation_to_apply=self._CreateBinaryAddF32Computation(),
+ dimensions=[1])
+ self._ExecuteAndCompareClose(c, expected=[6, 15])
+
+ def testReduce2DTo1DDim1F64(self):
+ input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+ c = self._NewComputation()
+ c.Reduce(
+ operand=c.Constant(input_array),
+ init_value=c.ConstantF64Scalar(0),
+ computation_to_apply=self._CreateBinaryAddF64Computation(),
+ dimensions=[1])
+ self._ExecuteAndCompareClose(c, expected=[6, 15])
+
+ def testReduce3DAllPossibleWaysF32(self):
+ input_array = self._MakeSample3DArrayF32()
+
+ def _ReduceAndTest(*dims):
+ c = self._NewComputation()
+ c.Reduce(
+ operand=c.Constant(input_array),
+ init_value=c.ConstantF32Scalar(0),
+ computation_to_apply=self._CreateBinaryAddF32Computation(),
+ dimensions=dims)
+ self._ExecuteAndCompareClose(
+ c, expected=np.sum(input_array, axis=tuple(dims)))
+
+ _ReduceAndTest(0)
+    _ReduceAndTest(1)
+    _ReduceAndTest(2)
+ _ReduceAndTest(0, 1)
+ _ReduceAndTest(0, 2)
+ _ReduceAndTest(1, 2)
+ _ReduceAndTest(0, 1, 2)
+
+ def testReduce3DAllPossibleWaysF64(self):
+ input_array = self._MakeSample3DArrayF64()
+
+ def _ReduceAndTest(*dims):
+ c = self._NewComputation()
+ c.Reduce(
+ operand=c.Constant(input_array),
+ init_value=c.ConstantF64Scalar(0),
+ computation_to_apply=self._CreateBinaryAddF64Computation(),
+ dimensions=dims)
+ self._ExecuteAndCompareClose(
+ c, expected=np.sum(input_array, axis=tuple(dims)))
+
+ _ReduceAndTest(0)
+    _ReduceAndTest(1)
+    _ReduceAndTest(2)
+ _ReduceAndTest(0, 1)
+ _ReduceAndTest(0, 2)
+ _ReduceAndTest(1, 2)
+ _ReduceAndTest(0, 1, 2)
+
+ def testWhileF32(self):
+ cond = self._CreateTestF32Lt10Computation()
+ body = self._CreateMulF32By2Computation()
+ c = self._NewComputation()
+ init = c.ConstantF32Scalar(1.)
+ c.While(cond, body, init)
+ self._ExecuteAndCompareClose(c, expected=16.)
+
+ def testWhileF64(self):
+ cond = self._CreateTestF64Lt10Computation()
+ body = self._CreateMulF64By2Computation()
+ c = self._NewComputation()
+ init = c.ConstantF64Scalar(1.)
+ c.While(cond, body, init)
+ self._ExecuteAndCompareClose(c, expected=16.)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tensorflow/tf_exported_symbols.lds b/tensorflow/tf_exported_symbols.lds
index bddb87f00c..3ff824e5e1 100644
--- a/tensorflow/tf_exported_symbols.lds
+++ b/tensorflow/tf_exported_symbols.lds
@@ -4,3 +4,4 @@
*TF_*
*TFE_*
*nsync_*
+*pywrap_xla*
diff --git a/tensorflow/tf_version_script.lds b/tensorflow/tf_version_script.lds
index 11f66c5c8b..6b28943f01 100644
--- a/tensorflow/tf_version_script.lds
+++ b/tensorflow/tf_version_script.lds
@@ -5,6 +5,7 @@ tensorflow {
*TF_*;
*TFE_*;
*nsync_*;
+ *pywrap_xla*;
local:
*;
};