diff options
-rw-r--r-- | tensorflow/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/BUILD | 10 | ||||
-rw-r--r-- | tensorflow/compiler/xla/python/BUILD | 82 | ||||
-rw-r--r-- | tensorflow/compiler/xla/python/__init__.py | 0 | ||||
-rw-r--r-- | tensorflow/compiler/xla/python/local_computation_builder.cc | 265 | ||||
-rw-r--r-- | tensorflow/compiler/xla/python/local_computation_builder.h | 210 | ||||
-rw-r--r-- | tensorflow/compiler/xla/python/local_computation_builder.i | 348 | ||||
-rw-r--r-- | tensorflow/compiler/xla/python/numpy_bridge.cc | 389 | ||||
-rw-r--r-- | tensorflow/compiler/xla/python/numpy_bridge.h | 123 | ||||
-rw-r--r-- | tensorflow/compiler/xla/python/xla.i | 18 | ||||
-rw-r--r-- | tensorflow/compiler/xla/python/xla_client.py | 605 | ||||
-rw-r--r-- | tensorflow/compiler/xla/python/xla_client_test.py | 898 | ||||
-rw-r--r-- | tensorflow/tf_exported_symbols.lds | 1 | ||||
-rw-r--r-- | tensorflow/tf_version_script.lds | 1 |
14 files changed, 2951 insertions, 0 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD index d80fe5c829..9437bef99f 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -411,6 +411,7 @@ filegroup( "//tensorflow/compiler/xla/client:all_files", "//tensorflow/compiler/xla/client/lib:all_files", "//tensorflow/compiler/xla/legacy_flags:all_files", + "//tensorflow/compiler/xla/python:all_files", "//tensorflow/compiler/xla/service:all_files", "//tensorflow/compiler/xla/service/cpu:all_files", "//tensorflow/compiler/xla/service/gpu:all_files", diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index d3f292207f..cd69c69889 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -20,6 +20,10 @@ package_group( load("//tensorflow:tensorflow.bzl", "cc_header_only_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_py", +) # Filegroup used to collect source files for dependency checking. 
filegroup( @@ -36,6 +40,12 @@ xla_proto_library( visibility = ["//visibility:public"], ) +tf_proto_library_py( + name = "xla_data_proto", # bzl adds a _py suffix + srcs = ["xla_data.proto"], + visibility = ["//visibility:public"], +) + xla_proto_library( name = "xla_proto", srcs = ["xla.proto"], diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD new file mode 100644 index 0000000000..a6b8158671 --- /dev/null +++ b/tensorflow/compiler/xla/python/BUILD @@ -0,0 +1,82 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") + +py_library( + name = "xla_client", + srcs = ["xla_client.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":pywrap_xla", + "//tensorflow/compiler/xla:xla_data_proto_py", + ], +) + +py_test( + name = "xla_client_test", + srcs = ["xla_client_test.py"], + main = "xla_client_test.py", + srcs_version = "PY2AND3", + deps = [ + ":xla_client", + "//tensorflow/python:platform_test", + ], +) + +cc_library( + name = "numpy_bridge", + srcs = ["numpy_bridge.cc"], + hdrs = ["numpy_bridge.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/python:numpy_lib", + ], +) + +cc_library( + name = "local_computation_builder", + srcs = ["local_computation_builder.cc"], + hdrs = ["local_computation_builder.h"], + deps = [ + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/core:lib", + ], +) + +tf_py_wrap_cc( + name = "pywrap_xla", + srcs = ["xla.i"], + swig_includes = [ + "local_computation_builder.i", + 
], + deps = [ + ":local_computation_builder", + ":numpy_bridge", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/python/__init__.py b/tensorflow/compiler/xla/python/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tensorflow/compiler/xla/python/__init__.py diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc new file mode 100644 index 0000000000..0b0a53fac7 --- /dev/null +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -0,0 +1,265 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/python/local_computation_builder.h" +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +namespace swig { + +CompiledLocalComputation::CompiledLocalComputation( + std::unique_ptr<LocalExecutable> executable) + : executable_(std::move(executable)) {} + +std::unique_ptr<Literal> CompiledLocalComputation::Execute( + const std::vector<Literal>& arguments) { + LocalClient* client = ClientLibrary::LocalClientOrDie(); + + // Transfer arguments in + std::vector<std::unique_ptr<ScopedShapedBuffer>> scoped_buffers; + scoped_buffers.reserve(arguments.size()); + for (const Literal& argument : arguments) { + scoped_buffers.push_back( + client + ->LiteralToShapedBuffer(argument, + /*device_ordinal=*/0, + client->backend().memory_allocator()) + .ConsumeValueOrDie()); + } + + // Execute + std::vector<const ShapedBuffer*> argument_buffers; + argument_buffers.reserve(scoped_buffers.size()); + for (auto& buffer : scoped_buffers) { + argument_buffers.push_back(buffer.get()); + } + ExecutableRunOptions options; + options.set_allocator(client->backend().memory_allocator()); + options.set_inter_op_thread_pool(client->backend().inter_op_thread_pool()); + options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + std::unique_ptr<ScopedShapedBuffer> result_buffer = + executable_->Run(argument_buffers, options).ConsumeValueOrDie(); + + // Transfer result out + return client->ShapedBufferToLiteral(*result_buffer).ConsumeValueOrDie(); +} + +LocalComputation::LocalComputation(std::unique_ptr<Computation> computation) + : computation_(std::move(computation)) {} + +CompiledLocalComputation* LocalComputation::Compile( + const std::vector<Shape>& argument_shapes) { + std::vector<const Shape*> argument_shape_pointers; + argument_shape_pointers.reserve(argument_shapes.size()); 
+ for (auto& argument_shape : argument_shapes) { + argument_shape_pointers.push_back(&argument_shape); + } + + LocalClient* client = ClientLibrary::LocalClientOrDie(); + ExecutableBuildOptions options; + return new CompiledLocalComputation( + client->Compile(*computation_, argument_shape_pointers, options) + .ValueOrDie()); +} + +const Computation& LocalComputation::computation() const { + return *computation_; +} + +LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) + : builder_(ClientLibrary::LocalClientOrDie(), computation_name) {} + +LocalComputation* LocalComputationBuilder::Build() { + return new LocalComputation(std::unique_ptr<Computation>( + new Computation(builder_.Build().ConsumeValueOrDie()))); +} + +ComputationDataHandle LocalComputationBuilder::Parameter(int64 parameter_number, + const Shape& shape, + const string& name) { + return builder_.Parameter(parameter_number, shape, name); +} + +std::unique_ptr<Shape> LocalComputationBuilder::GetShape( + const ComputationDataHandle& operand) { + return builder_.GetShape(operand).ConsumeValueOrDie(); +} + +ComputationDataHandle LocalComputationBuilder::ConstantLiteral( + const Literal& literal) { + return builder_.ConstantLiteral(literal); +} + +ComputationDataHandle LocalComputationBuilder::Broadcast( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice<int64> broadcast_sizes) { + return builder_.Broadcast(operand, broadcast_sizes); +} + +ComputationDataHandle LocalComputationBuilder::Reshape( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice<int64> dimensions, + tensorflow::gtl::ArraySlice<int64> new_sizes) { + return builder_.Reshape(operand, dimensions, new_sizes); +} + +ComputationDataHandle LocalComputationBuilder::Slice( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice<int64> start_indices, + tensorflow::gtl::ArraySlice<int64> limit_indices, + tensorflow::gtl::ArraySlice<int64> strides) { + return 
builder_.Slice(operand, start_indices, limit_indices, strides); +} + +ComputationDataHandle LocalComputationBuilder::DynamicSlice( + const ComputationDataHandle& operand, + const ComputationDataHandle& start_indices, + tensorflow::gtl::ArraySlice<int64> slice_sizes) { + return builder_.DynamicSlice(operand, start_indices, slice_sizes); +} + +ComputationDataHandle LocalComputationBuilder::DynamicUpdateSlice( + const ComputationDataHandle& operand, const ComputationDataHandle& update, + const ComputationDataHandle& start_indices) { + return builder_.DynamicUpdateSlice(operand, update, start_indices); +} + +ComputationDataHandle LocalComputationBuilder::ConcatInDim( + tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, + int64 dimension) { + return builder_.ConcatInDim(operands, dimension); +} + +ComputationDataHandle LocalComputationBuilder::Select( + const ComputationDataHandle& pred, const ComputationDataHandle& on_true, + const ComputationDataHandle& on_false) { + return builder_.Select(pred, on_true, on_false); +} + +ComputationDataHandle LocalComputationBuilder::Tuple( + tensorflow::gtl::ArraySlice<ComputationDataHandle> elements) { + return builder_.Tuple(elements); +} + +ComputationDataHandle LocalComputationBuilder::GetTupleElement( + const ComputationDataHandle& tuple_data, int64 index) { + return builder_.GetTupleElement(tuple_data, index); +} + +ComputationDataHandle LocalComputationBuilder::Dot( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { + return builder_.Dot(lhs, rhs); +} + +ComputationDataHandle LocalComputationBuilder::ConvertElementType( + const ComputationDataHandle& operand, PrimitiveType new_element_type) { + return builder_.ConvertElementType(operand, new_element_type); +} + +ComputationDataHandle LocalComputationBuilder::Call( + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice<ComputationDataHandle> operands) { + return builder_.Call(local_computation.computation(), operands); +} + 
+ComputationDataHandle LocalComputationBuilder::Transpose( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice<int64> permutation) { + return builder_.Transpose(operand, permutation); +} + +ComputationDataHandle LocalComputationBuilder::Map( + tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice<int64> dimensions, + tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands) { + return builder_.Map(operands, local_computation.computation(), dimensions, + static_operands); +} + +ComputationDataHandle LocalComputationBuilder::Reduce( + const ComputationDataHandle& operand, + const ComputationDataHandle& init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) { + return builder_.Reduce(operand, init_value, local_computation.computation(), + dimensions_to_reduce); +} + +ComputationDataHandle LocalComputationBuilder::While( + const LocalComputation& condition, const LocalComputation& body, + const ComputationDataHandle& init) { + return builder_.While(condition.computation(), body.computation(), init); +} + +#define _FORWARD(method_name, return_sig, args_sig, args) \ + return_sig LocalComputationBuilder::method_name args_sig { \ + return builder_.method_name args; \ + } + +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, ComputationDataHandle, \ + (const ComputationDataHandle& operand), (operand)) + +#define _FORWARD_BINOP(method_name) \ + _FORWARD( \ + method_name, ComputationDataHandle, \ + (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions), \ + (lhs, rhs, broadcast_dimensions)) + +_FORWARD_BINOP(Eq) +_FORWARD_BINOP(Ne) +_FORWARD_BINOP(Ge) +_FORWARD_BINOP(Gt) +_FORWARD_BINOP(Lt) +_FORWARD_BINOP(Le) +_FORWARD_BINOP(Add) +_FORWARD_BINOP(Sub) +_FORWARD_BINOP(Mul) +_FORWARD_BINOP(Div) +_FORWARD_BINOP(Rem) 
+_FORWARD_BINOP(Max) +_FORWARD_BINOP(Min) +_FORWARD_BINOP(And) +_FORWARD_BINOP(Or) +_FORWARD_UNOP(Not) +_FORWARD_UNOP(Abs) +_FORWARD_UNOP(Exp) +_FORWARD_UNOP(Floor) +_FORWARD_UNOP(Ceil) +_FORWARD_UNOP(Log) +_FORWARD_UNOP(Sign) +_FORWARD_UNOP(Cos) +_FORWARD_UNOP(Sin) +_FORWARD_UNOP(Tanh) +_FORWARD_UNOP(SqrtF32) +_FORWARD_UNOP(SquareF32) +_FORWARD_BINOP(Pow) +_FORWARD_UNOP(IsFinite) +_FORWARD_UNOP(ReciprocalF32) +_FORWARD_UNOP(Neg) +_FORWARD_UNOP(Sort) + +#undef _FORWARD +#undef _FORWARD_UNOP +#undef _FORWARD_BINOP + +} // namespace swig + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h new file mode 100644 index 0000000000..cbab45a5f0 --- /dev/null +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -0,0 +1,210 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ + +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +namespace swig { + +// Wraps a LocalExecutable produced by compiling a +// LocalComputation. The Execute method forwards to that of the +// underlying LocalExecutable, and additionally handles tranferring +// arguments and return values in and back out of the client library's +// local client. This class is intended to be made available to Python +// via SWIG. +class CompiledLocalComputation { + public: + CompiledLocalComputation(std::unique_ptr<LocalExecutable> executable); + std::unique_ptr<Literal> Execute(const std::vector<Literal>& arguments); + + private: + std::unique_ptr<LocalExecutable> executable_; +}; + +// Wraps a Computation produced by a LocalComputationBuilder. The +// Compile method compiles the computation to a (local) executable via +// the client library's local client. This class is intended to be +// made available to Python via SWIG. +class LocalComputation { + public: + LocalComputation(std::unique_ptr<Computation> computation); + CompiledLocalComputation* Compile(const std::vector<Shape>& argument_shapes); + const Computation& computation() const; + + private: + std::unique_ptr<Computation> computation_; +}; + +// Wraps the ComputationBuilder API in order to: +// - Support consumption by SWIG in order to be made available to +// Python. +// - Set up the underlying builder to use the client library's +// LocalClient. +// - Wrap Computations in LocalComputations for Python access. +// - Correspondingly unwrap incoming LocalComputations. 
+class LocalComputationBuilder { + public: + LocalComputationBuilder(const string& computation_name); + + LocalComputation* Build(); + + ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape, + const string& name); + + std::unique_ptr<Shape> GetShape(const ComputationDataHandle& operand); + + ComputationDataHandle ConstantLiteral(const Literal& literal); + + ComputationDataHandle Broadcast( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice<int64> broadcast_sizes); + + ComputationDataHandle Reshape(const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice<int64> dimensions, + tensorflow::gtl::ArraySlice<int64> new_sizes); + + ComputationDataHandle Slice(const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice<int64> start_indices, + tensorflow::gtl::ArraySlice<int64> limit_indices, + tensorflow::gtl::ArraySlice<int64> strides); + + ComputationDataHandle DynamicSlice( + const ComputationDataHandle& operand, + const ComputationDataHandle& start_indices, + tensorflow::gtl::ArraySlice<int64> slice_sizes); + + ComputationDataHandle DynamicUpdateSlice( + const ComputationDataHandle& operand, const ComputationDataHandle& update, + const ComputationDataHandle& start_indices); + + ComputationDataHandle ConcatInDim( + tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, + int64 dimension); + + ComputationDataHandle Select(const ComputationDataHandle& pred, + const ComputationDataHandle& on_true, + const ComputationDataHandle& on_false); + + ComputationDataHandle Tuple( + tensorflow::gtl::ArraySlice<ComputationDataHandle> elements); + + ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data, + int64 index); + + ComputationDataHandle Dot(const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs); + + ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, + PrimitiveType new_element_type); + + ComputationDataHandle Call( + const LocalComputation& 
local_computation, + tensorflow::gtl::ArraySlice<ComputationDataHandle> operands); + + ComputationDataHandle Transpose( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice<int64> permutation); + + ComputationDataHandle Map( + tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice<int64> dimensions, + tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands); + + ComputationDataHandle Reduce( + const ComputationDataHandle& operand, + const ComputationDataHandle& init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce); + + ComputationDataHandle While(const LocalComputation& condition, + const LocalComputation& body, + const ComputationDataHandle& init); + +#define _FORWARD(method_name, return_sig, args_sig) \ + return_sig method_name args_sig; + +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, ComputationDataHandle, \ + (const ComputationDataHandle& operand)) + +#define _FORWARD_BINOP(method_name) \ + _FORWARD( \ + method_name, ComputationDataHandle, \ + (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions)) + + _FORWARD_BINOP(Eq) + _FORWARD_BINOP(Ne) + _FORWARD_BINOP(Ge) + _FORWARD_BINOP(Gt) + _FORWARD_BINOP(Lt) + _FORWARD_BINOP(Le) + _FORWARD_BINOP(Add) + _FORWARD_BINOP(Sub) + _FORWARD_BINOP(Mul) + _FORWARD_BINOP(Div) + _FORWARD_BINOP(Rem) + _FORWARD_BINOP(Max) + _FORWARD_BINOP(Min) + _FORWARD_BINOP(And) + _FORWARD_BINOP(Or) + _FORWARD_UNOP(Not) + _FORWARD_UNOP(Abs) + _FORWARD_UNOP(Exp) + _FORWARD_UNOP(Floor) + _FORWARD_UNOP(Ceil) + _FORWARD_UNOP(Log) + _FORWARD_UNOP(Sign) + _FORWARD_UNOP(Cos) + _FORWARD_UNOP(Sin) + _FORWARD_UNOP(Tanh) + _FORWARD_UNOP(SqrtF32) + _FORWARD_UNOP(SquareF32) + _FORWARD_BINOP(Pow) + _FORWARD_UNOP(IsFinite) + _FORWARD_UNOP(ReciprocalF32) + _FORWARD_UNOP(Neg) + _FORWARD_UNOP(Sort) + 
+#undef _FORWARD +#undef _FORWARD_UNOP +#undef _FORWARD_BINOP + + private: + ComputationBuilder builder_; +}; + +static void DeleteLocalComputation(LocalComputation* computation) { + delete computation; +} + +static void DeleteCompiledLocalComputation( + CompiledLocalComputation* computation) { + delete computation; +} + +} // namespace swig + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i new file mode 100644 index 0000000000..ac8f3e4277 --- /dev/null +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -0,0 +1,348 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// SWIG typemaps and declarations for building, compiling, and +// executing XLA computations, wrapping most of what is declared in +// local_computation_builder.h. 
+// +// The typemaps below implement/assert the following correspondences +// (with elaborations below): +// +// C++ Python +// -------------------------------------+--------------------------------------- +// ComputationDataHandle <-> long +// ArraySlice<int64> <- sequence of long +// ArraySlice<ComputationDataHandle> <- sequence of long +// Literal <-> (nested tuple of) numpy ndarray +// std::vector<Literal> <- sequence of (nested tuple of) ndarray +// Shape <-> pair holding (dtype, dimensions) +// std::vector<Shape> <- sequence of shape information pairs +// PrimitiveType <- int +// +// Arrows indicate whether a conversion only ever occurs in one +// direction, or whether it is maintained bidirectionally. Also, +// "long" and "int" denote the Python types so named, not C. +// +// The Python objects corresponding to C++ Literals have the type: +// +// T = ndarray | (T, ...) +// +// where a terminal numpy ndarray translates to a Literal with a +// non-tuple Shape, an XLA primitive element type corresponding to the +// ndarray's dtype. Meanwhile, a non-terminal "tuple of T" translates +// to a tuple-shaped Literal whose tuple components are translated +// recursively. For example, if x is a numpy ndarray in Python, with +// shape (2, 3) and dtype of dtype('float32'), then x translates to a +// Literal with rank 2, dimension 2 and 3, and XLA primitive type +// F32. Meanwhile, +// +// (x, (x, x), (x,)), +// +// translates to a tuple-shaped XLA Literal, whose component subshapes +// are a 2x3 F32-shaped literal followed by two tuple-shaped literals. +// +// The Python objects corresponding to C++ Shapes have the type: +// +// T = (dtype, S) +// S = DIMENSIONS | TUPLE_SHAPES +// DIMENSIONS = (int, ...) +// TUPLE_SHAPES = (T, ...) +// +// In the pair described by the T rule, the terminal dtype determines +// whether S expands as DIMENSIONS or TUPLE_SHAPES. 
Namely if it is +// dtype('O'), numpy's object dtype, the structure represents a tuple +// shape and the expansion of the non-terminal S is +// TUPLE_SHAPES. Otherwise, dtype describes a primitive element type +// and S expands into DIMENSIONS giving dimension sizes. For example: +// +// (dtype('float32'), (3, 5, 7)) +// +// describes a 3x5x7 array of F32s, and +// +// (dtype('O'), ((dtype('float32'), (2, 3)), +// (dtype('float64'), (4, 5)))) +// +// describes a tuple shape with two subshapes: the first a 2x3 F32, +// and the other a 4x5 F64. +// +// The Python int corresponding to a PrimitiveType enum must be valid +// per xla_data.proto (e.g. xla_data.PRED, xla_data.F32). +// +// The SWIG object wrappers generated by this file are not intended +// for end use, but rather for internal use in the Python XLA client, +// xla_client.py. +// +// One central reason for the Python-side indirection is that the +// Python-side objects produced by the typemaps in this file are +// further packaged up by xla_client before being passed on. For +// instance, xla_client wraps the long produced for a C++ +// ComputationDataHandle in a Python ComputationDataHandle proto, +// rather than exposing a raw long outside of the client. Similarly, +// the Python pair produced for a C++ Shape is further wrapped in a +// Python class (xla_client.Shape) so as not to expose the raw pair +// externally. +// +// Other SWIG object wrappers (e.g. of LocalComputation) are further +// wrapped by xla_client in order to set up a custom destructor that +// triggers memory deallocation on the C++ side. 
+ +%include "tensorflow/python/platform/base.i" + +%{ +// Must be included first +#include "tensorflow/python/lib/core/numpy.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/compiler/xla/python/numpy_bridge.h" +#include "tensorflow/compiler/xla/python/local_computation_builder.h" + +using namespace xla; +using namespace xla::swig; +%} + +// Required to use PyArray_* functions. +%init %{ +tensorflow::ImportNumpy(); +%} + +// ComputationDataHandle + +%typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) { + const int64 handle = numpy::PyIntOrPyLongToLong($input); + if (handle == -1 && PyErr_Occurred()) { + return NULL; + } + temp.set_handle(handle); + $1 = &temp; +} + +%typemap(out) ComputationDataHandle { + $result = numpy::LongToPyIntOrPyLong($1.handle()); +} + +// ArraySlice<int64> + +%typemap(in) tensorflow::gtl::ArraySlice<int64> + (std::vector<int64> temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + temps.resize(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + PyObject* py_int = numpy::PyNumberToPyInt(o); + if (!py_int) { + PyErr_SetString( + PyExc_TypeError, + "Argument sequence element cannot be converted to int"); + Py_DECREF(o); + return NULL; + } + temps[i] = numpy::PyIntOrPyLongToLong(py_int); + if (temps[i] == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + Py_DECREF(o); + return NULL; + } + Py_DECREF(py_int); + Py_DECREF(o); + } + $1 = temps; +} + +// ComputationDataHandle + +%typemap(in) tensorflow::gtl::ArraySlice<ComputationDataHandle> + (std::vector<ComputationDataHandle> temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return 
NULL; + } + const int size = PySequence_Size($input); + temps.resize(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + PyObject* py_int = numpy::PyNumberToPyInt(o); + if (!py_int) { + PyErr_SetString( + PyExc_TypeError, + "Argument sequence element cannot be converted to int"); + return NULL; + } + const int64 handle = numpy::PyIntOrPyLongToLong(py_int); + if (handle == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + Py_DECREF(o); + return NULL; + } + temps[i].set_handle(handle); + Py_DECREF(py_int); + Py_DECREF(o); + } + $1 = temps; +} + +// Literal + +%typemap(in) const Literal& (std::unique_ptr<Literal> temp) { + temp = numpy::XlaLiteralFromPyObject($input); + $1 = &*temp; +} + +%typemap(out) std::unique_ptr<Literal> { + $result = numpy::PyObjectFromXlaLiteral(*$1); +} + +%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + temps.push_back(*numpy::XlaLiteralFromPyObject(o)); + Py_DECREF(o); + } + $1 = &temps; +} + +// Shape + +%typemap(in) const Shape& (Shape temp) { + if (!numpy::CheckPyShapeInfo($input)) { + return NULL; + } + temp = numpy::XlaShapeFromPyShapeInfo($input); + $1 = &temp; +} + +%typemap(out) std::unique_ptr<Shape> { + $result = numpy::PyShapeInfoFromXlaShape(*$1); +} + +%typemap(in) const std::vector<Shape>& (std::vector<Shape> temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + if (!numpy::CheckPyShapeInfo(o)) { + Py_DECREF(o); + return NULL; + } + temps.push_back(numpy::XlaShapeFromPyShapeInfo(o)); + Py_DECREF(o); + } + $1 = &temps; +} + 
+// PrimitiveType + +%typemap(in) PrimitiveType { + PyObject* py_int = numpy::PyNumberToPyInt($input); + if (!py_int) { + PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int"); + return NULL; + } + const long value = numpy::PyIntOrPyLongToLong(py_int); + if (value == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + return NULL; + } + if (!PrimitiveType_IsValid(value)) { + PyErr_SetString( + PyExc_TypeError, "Argument not valid for PrimitiveType enum"); + Py_DECREF(py_int); + return NULL; + } + $1 = static_cast<PrimitiveType>(value); +} + +%ignoreall +%unignore xla; +%unignore xla::swig; +%unignore xla::swig::CompiledLocalComputation; +%unignore xla::swig::CompiledLocalComputation::Execute; +%unignore xla::swig::LocalComputation; +%unignore xla::swig::LocalComputation::Compile; +%unignore xla::swig::LocalComputationBuilder; +%unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; +%unignore xla::swig::LocalComputationBuilder::Build; +%unignore xla::swig::LocalComputationBuilder::Parameter; +%unignore xla::swig::LocalComputationBuilder::GetShape; +%unignore xla::swig::LocalComputationBuilder::ConstantLiteral; +%unignore xla::swig::LocalComputationBuilder::ConstantR0; +%unignore xla::swig::LocalComputationBuilder::Broadcast; +%unignore xla::swig::LocalComputationBuilder::Reshape; +%unignore xla::swig::LocalComputationBuilder::Slice; +%unignore xla::swig::LocalComputationBuilder::DynamicSlice; +%unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice; +%unignore xla::swig::LocalComputationBuilder::ConcatInDim; +%unignore xla::swig::LocalComputationBuilder::Select; +%unignore xla::swig::LocalComputationBuilder::Tuple; +%unignore xla::swig::LocalComputationBuilder::GetTupleElement; +%unignore xla::swig::LocalComputationBuilder::ConvertElementType; +%unignore xla::swig::LocalComputationBuilder::Call; +%unignore xla::swig::LocalComputationBuilder::Transpose; +%unignore xla::swig::LocalComputationBuilder::Map; +%unignore 
xla::swig::LocalComputationBuilder::Reduce; +%unignore xla::swig::LocalComputationBuilder::While; +%unignore xla::swig::LocalComputationBuilder::Eq; +%unignore xla::swig::LocalComputationBuilder::Ne; +%unignore xla::swig::LocalComputationBuilder::Ge; +%unignore xla::swig::LocalComputationBuilder::Gt; +%unignore xla::swig::LocalComputationBuilder::Lt; +%unignore xla::swig::LocalComputationBuilder::Le; +%unignore xla::swig::LocalComputationBuilder::Dot; +%unignore xla::swig::LocalComputationBuilder::Add; +%unignore xla::swig::LocalComputationBuilder::Sub; +%unignore xla::swig::LocalComputationBuilder::Mul; +%unignore xla::swig::LocalComputationBuilder::Div; +%unignore xla::swig::LocalComputationBuilder::Rem; +%unignore xla::swig::LocalComputationBuilder::Max; +%unignore xla::swig::LocalComputationBuilder::Min; +%unignore xla::swig::LocalComputationBuilder::And; +%unignore xla::swig::LocalComputationBuilder::Or; +%unignore xla::swig::LocalComputationBuilder::Not; +%unignore xla::swig::LocalComputationBuilder::Abs; +%unignore xla::swig::LocalComputationBuilder::Exp; +%unignore xla::swig::LocalComputationBuilder::Floor; +%unignore xla::swig::LocalComputationBuilder::Ceil; +%unignore xla::swig::LocalComputationBuilder::Log; +%unignore xla::swig::LocalComputationBuilder::Sign; +%unignore xla::swig::LocalComputationBuilder::Cos; +%unignore xla::swig::LocalComputationBuilder::Sin; +%unignore xla::swig::LocalComputationBuilder::Tanh; +%unignore xla::swig::LocalComputationBuilder::SqrtF32; +%unignore xla::swig::LocalComputationBuilder::SquareF32; +%unignore xla::swig::LocalComputationBuilder::Pow; +%unignore xla::swig::LocalComputationBuilder::IsFinite; +%unignore xla::swig::LocalComputationBuilder::ReciprocalF32; +%unignore xla::swig::LocalComputationBuilder::Neg; +%unignore xla::swig::LocalComputationBuilder::Sort; +%unignore xla::swig::DeleteLocalComputation; +%unignore xla::swig::DeleteCompiledLocalComputation; + +%include 
"tensorflow/compiler/xla/python/local_computation_builder.h" + +%unignoreall diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc new file mode 100644 index 0000000000..b30bdc3669 --- /dev/null +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -0,0 +1,389 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/numpy_bridge.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace swig { + +namespace numpy { + +int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) { + switch (primitive_type) { + case PRED: + return NPY_BOOL; + case S8: + return NPY_INT8; + case S16: + return NPY_INT16; + case S32: + return NPY_INT32; + case S64: + return NPY_INT64; + case U8: + return NPY_UINT8; + case U16: + return NPY_UINT16; + case U32: + return NPY_UINT32; + case U64: + return NPY_UINT64; + case F16: + return NPY_FLOAT16; + case F32: + return NPY_FLOAT32; + case F64: + return NPY_FLOAT64; + case TUPLE: + return NPY_OBJECT; + default: + LOG(FATAL) << "No Numpy type for XLA primitive type " << primitive_type; + } +} + +PrimitiveType NumpyTypeToPrimitiveType(int np_type) { + switch (np_type) { + case NPY_BOOL: + return PRED; + case NPY_INT8: + return S8; + case NPY_INT16: + return S16; 
+ case NPY_INT32: + return S32; + case NPY_INT64: + return S64; + case NPY_UINT8: + return U8; + case NPY_UINT16: + return U16; + case NPY_UINT32: + return U32; + case NPY_UINT64: + return U64; + case NPY_FLOAT16: + return F16; + case NPY_FLOAT32: + return F32; + case NPY_FLOAT64: + return F64; + case NPY_OBJECT: + return TUPLE; + default: + LOG(FATAL) << "No XLA primitive type for Numpy type " << np_type; + } +} + +bool NumpyTypeIsValid(int np_type) { + switch (np_type) { + case NPY_BOOL: + case NPY_INT8: + case NPY_INT16: + case NPY_INT32: + case NPY_INT64: + case NPY_UINT8: + case NPY_UINT16: + case NPY_UINT32: + case NPY_UINT64: + case NPY_FLOAT16: + case NPY_FLOAT32: + case NPY_FLOAT64: + case NPY_OBJECT: + return true; + default: + return false; + } +} + +PyObject* PyShapeInfoFromXlaShape(const Shape& shape) { + int np_typenum = PrimitiveTypeToNumpyType(shape.element_type()); + PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum); + + PyObject* dimensions; + if (ShapeUtil::IsTuple(shape)) { + int num_elements = ShapeUtil::TupleElementCount(shape); + dimensions = PyTuple_New(ShapeUtil::TupleElementCount(shape)); + for (int i = 0; i < num_elements; ++i) { + PyTuple_SET_ITEM( + dimensions, i, + PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i))); + } + } else { + int rank = ShapeUtil::Rank(shape); + dimensions = PyTuple_New(rank); + for (int i = 0; i < rank; ++i) { + PyTuple_SET_ITEM(dimensions, i, + LongToPyIntOrPyLong(ShapeUtil::GetDimension(shape, i))); + } + } + return PyTuple_Pack(2, np_dtype, dimensions); +} + +// Precondition: o->ob_type == &PyArrayDescr_Type +static int NumpyTypenum(PyObject* o) { + return reinterpret_cast<PyArray_Descr*>(o)->type_num; +} + +bool CheckPyShapeInfo(PyObject* o) { + // The object is a tuple (a pair) + if (!PyTuple_Check(o)) { + PyErr_SetString(PyExc_TypeError, "Shape record must be a tuple"); + return false; + } + if (PyTuple_Size(o) != 2) { + PyErr_SetString(PyExc_ValueError, "Shape record tuple must 
be of length 2"); + return false; + } + + // It has a first element, which is a numpy dtype object + PyObject* first = PyTuple_GetItem(o, 0); + if (!first) { + return false; + } + if (first->ob_type != &PyArrayDescr_Type) { + PyErr_SetString( + PyExc_TypeError, + "Shape record does not have a numpy dtype as its first element"); + return false; + } + const int np_type = NumpyTypenum(first); + if (!NumpyTypeIsValid(np_type)) { + PyErr_SetString(PyExc_ValueError, + "Shape record has an invalid integer dtype"); + return false; + } + + // It has a second element, which is a tuple, either of shape + // records or of Python ints + PyObject* second = PyTuple_GetItem(o, 1); + if (!second) { + return false; + } + if (!PyTuple_Check(second)) { + PyErr_SetString(PyExc_TypeError, + "Shape record does not have a tuple as its second element"); + return false; + } + const int length = PyTuple_Size(second); + const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type); + for (int i = 0; i < length; i++) { + PyObject* dimension = PyTuple_GetItem(second, i); + if (element_type == TUPLE) { + if (!CheckPyShapeInfo(dimension)) { + return false; + } + } else if (!CheckPyIntOrLong(dimension)) { + PyErr_SetString(PyExc_TypeError, + "Non-tuple shape record has a non-integer dimension"); + return false; + } + } + + return true; +} + +// Precondition: CheckPyShapeInfo(o) +Shape XlaShapeFromPyShapeInfo(PyObject* o) { + const int np_type = NumpyTypenum(PyTuple_GetItem(o, 0)); + const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type); + PyObject* py_dimensions = PyTuple_GetItem(o, 1); + const int length = PyTuple_Size(py_dimensions); + if (element_type == TUPLE) { + std::vector<Shape> subshapes; + subshapes.reserve(length); + for (int i = 0; i < length; i++) { + subshapes.push_back( + XlaShapeFromPyShapeInfo(PyTuple_GetItem(py_dimensions, i))); + } + return ShapeUtil::MakeTupleShape(subshapes); + } else { + std::vector<int64> dimensions(length); + for (int i = 0; i < 
length; i++) { + dimensions[i] = PyIntOrPyLongToLong(PyTuple_GetItem(py_dimensions, i)); + if (dimensions[i] == -1) { + CHECK(!PyErr_Occurred()); + } + } + return ShapeUtil::MakeShape(element_type, dimensions); + } +} + +PyObject* PyObjectFromXlaLiteral(const Literal& literal) { + if (ShapeUtil::IsTuple(literal.shape())) { + const std::vector<Literal>& tuple_literals = literal.tuple_literals(); + int num_elements = ShapeUtil::TupleElementCount(literal.shape()); + PyObject* tuple = PyTuple_New(num_elements); + for (int i = 0; i < num_elements; i++) { + PyTuple_SET_ITEM(tuple, i, PyObjectFromXlaLiteral(tuple_literals[i])); + } + return tuple; + } else { + int rank = ShapeUtil::Rank(literal.shape()); + std::vector<long> dimensions(rank); // NOLINT - PyArray requires a long* + for (int i = 0; i < rank; i++) { + dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i); + } + int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type()); + PyObject* array = + PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0); + CopyLiteralToNumpyArray(np_type, literal, + reinterpret_cast<PyArrayObject*>(array)); + return array; + } +} + +std::unique_ptr<Literal> XlaLiteralFromPyObject(PyObject* o) { + if (PyTuple_Check(o)) { + int num_elements = PyTuple_Size(o); + std::vector<std::unique_ptr<Literal>> elements; + elements.reserve(num_elements); + for (int i = 0; i < num_elements; i++) { + PyObject* element = PyTuple_GetItem(o, i); + elements.push_back(XlaLiteralFromPyObject(element)); + } + return Literal::MakeTupleOwned(std::move(elements)); + } else if (PyArray_Check(o)) { + PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(o); + int rank = PyArray_NDIM(py_array); + std::vector<int64> dimensions(rank); + for (int i = 0; i < rank; i++) { + dimensions[i] = PyArray_DIM(py_array, i); + } + int np_type = PyArray_TYPE(py_array); + auto literal = Literal::CreateFromDimensions( + NumpyTypeToPrimitiveType(np_type), dimensions); + 
CopyNumpyArrayToLiteral(np_type, py_array, literal.get()); + return literal; + } else { + LOG(FATAL) + << "Non-tuple or Numpy array encountered in conversion to XLA literal"; + } +} + +void CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, + Literal* literal) { + switch (np_type) { + case NPY_BOOL: + CopyNumpyArrayToLiteral<bool>(py_array, literal); + break; + case NPY_INT32: + CopyNumpyArrayToLiteral<int32>(py_array, literal); + break; + case NPY_INT64: + CopyNumpyArrayToLiteral<int64>(py_array, literal); + break; + case NPY_UINT8: + CopyNumpyArrayToLiteral<uint8>(py_array, literal); + break; + case NPY_UINT32: + CopyNumpyArrayToLiteral<uint32>(py_array, literal); + break; + case NPY_UINT64: + CopyNumpyArrayToLiteral<uint64>(py_array, literal); + break; + case NPY_FLOAT16: + CopyNumpyArrayToLiteral<half>(py_array, literal); + break; + case NPY_FLOAT32: + CopyNumpyArrayToLiteral<float>(py_array, literal); + break; + case NPY_FLOAT64: + CopyNumpyArrayToLiteral<double>(py_array, literal); + break; + default: + LOG(FATAL) << "No XLA literal container for Numpy type " << np_type; + } +} + +void CopyLiteralToNumpyArray(int np_type, const Literal& literal, + PyArrayObject* py_array) { + switch (np_type) { + case NPY_BOOL: + CopyLiteralToNumpyArray<bool>(literal, py_array); + break; + case NPY_INT32: + CopyLiteralToNumpyArray<int32>(literal, py_array); + break; + case NPY_INT64: + CopyLiteralToNumpyArray<int64>(literal, py_array); + break; + case NPY_UINT8: + CopyLiteralToNumpyArray<uint8>(literal, py_array); + break; + case NPY_UINT32: + CopyLiteralToNumpyArray<uint32>(literal, py_array); + break; + case NPY_UINT64: + CopyLiteralToNumpyArray<uint64>(literal, py_array); + break; + case NPY_FLOAT16: + CopyLiteralToNumpyArray<half>(literal, py_array); + break; + case NPY_FLOAT32: + CopyLiteralToNumpyArray<float>(literal, py_array); + break; + case NPY_FLOAT64: + CopyLiteralToNumpyArray<double>(literal, py_array); + break; + default: + LOG(FATAL) << "No XLA 
literal container for Numpy type " << np_type; + } +} + +PyObject* LongToPyIntOrPyLong(long x) { // NOLINT +#if PY_MAJOR_VERSION < 3 + return PyInt_FromLong(x); +#else + return PyLong_FromLong(x); +#endif +} + +long PyIntOrPyLongToLong(PyObject* o) { // NOLINT +#if PY_MAJOR_VERSION < 3 + return PyInt_AsLong(o); +#else + return PyLong_AsLong(o); +#endif +} + +bool CheckPyIntOrLong(PyObject* o) { +#if PY_MAJOR_VERSION < 3 + return PyInt_Check(o); +#else + if (!PyLong_Check(o)) { + return false; + } + int overflow = 0; + PyLong_AsLongAndOverflow(o, &overflow); + return (overflow == 0); +#endif +} + +PyObject* PyNumberToPyInt(PyObject* o) { +#if PY_MAJOR_VERSION < 3 + return PyNumber_Int(o); +#else + return PyNumber_Long(o); +#endif +} + +} // namespace numpy + +} // namespace swig + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h new file mode 100644 index 0000000000..4e6ecbb0e8 --- /dev/null +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -0,0 +1,123 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// These functions transform Python/Numpy data structures to XLA data +// structures and vice versa, performing copies where +// appropriate. 
Python tuples and Numpy ndarrays translate to XLA +// tuples and XLA literals, respectively, and Numpy shape/dtype +// information is translated to XLA shape information. + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_ + +#include <algorithm> +#include <memory> + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/python/lib/core/numpy.h" + +namespace xla { + +namespace swig { + +namespace numpy { + +// Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy +// dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and +// vice versa. +int PrimitiveTypeToNumpyType(PrimitiveType primitive_type); +PrimitiveType NumpyTypeToPrimitiveType(int np_type); + +// Determines whether an integer-encoded Numpy dtype is valid, +// i.e. has a supported conversion to an XLA PrimitiveType. +bool NumpyTypeIsValid(int np_type); + +// Converts XLA shape information into a Python pair of the form +// (numpy dtype, dimensions). If the XLA shape represents a tuple, +// then the numpy dtype is NPY_OBJECT ('O') and `dimensions` is a +// Python tuple of shape-description pairs, created +// recursively. Otherwise, `dimensions` is a Python tuple-of-integers +// providing the array dimensions. +// +// The return value is a new reference. +PyObject* PyShapeInfoFromXlaShape(const Shape& shape); + +// Returns the outcome of a best-effort check that the Python object +// is a pair of the form (numpy dtype, dimensions), as produced by +// PyShapeInfoFromXlaShape. +bool CheckPyShapeInfo(PyObject* o); + +// Performs the inverse conversion to that of PyShapeInfoFromXlaShape. +// +// The return value is a new reference. +Shape XlaShapeFromPyShapeInfo(PyObject* o); + +// Converts an XLA literal to a Python object, either a Numpy ndarray +// or a nested Python tuple thereof. 
+// +// To avoid transferring ownership of the data buffers that underlie +// PyArrays and XLA literals, this function makes deep copies of all +// array data. +// +// The return value is a new reference. +PyObject* PyObjectFromXlaLiteral(const Literal& literal); + +// Converts a Numpy ndarray or a nested Python tuple thereof to a +// corresponding XLA literal. +// +// To avoid transferring ownership of the data buffers that underlie +// PyArrays and XLA literals, this function makes deep copies of all +// array data. +std::unique_ptr<Literal> XlaLiteralFromPyObject(PyObject* o); + +// The following functions copy array data from the buffers underlying Numpy +// ndarrays into those underlying XLA literals, and vice versa. + +void CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, + Literal* literal); + +void CopyLiteralToNumpyArray(int np_type, const Literal& literal, + PyArrayObject* py_array); + +template <typename NativeT> +void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) { + NativeT* source = static_cast<NativeT*>(PyArray_DATA(py_array)); + auto dest = literal->GetMutableArraySlice<NativeT>(); + std::copy(source, source + PyArray_SIZE(py_array), dest.data()); +} + +template <typename NativeT> +void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) { + NativeT* dest = static_cast<NativeT*>(PyArray_DATA(py_array)); + auto source = literal.GetArraySlice<NativeT>(); + std::copy(source.begin(), source.end(), dest); +} + +// Workarounds for Python 2 and 3 interop + +PyObject* LongToPyIntOrPyLong(long x); // NOLINT +long PyIntOrPyLongToLong(PyObject* o); // NOLINT +bool CheckPyIntOrLong(PyObject* o); +PyObject* PyNumberToPyInt(PyObject* o); + +} // namespace numpy + +} // namespace swig + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_ diff --git a/tensorflow/compiler/xla/python/xla.i b/tensorflow/compiler/xla/python/xla.i new file mode 100644 index 0000000000..1c4021a558 --- 
/dev/null +++ b/tensorflow/compiler/xla/python/xla.i @@ -0,0 +1,18 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* XLA-wide SWIG wrapper */ + +%include "tensorflow/compiler/xla/python/local_computation_builder.i" diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py new file mode 100644 index 0000000000..c75d54856d --- /dev/null +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -0,0 +1,605 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""An in-process, local XLA client in Python, supporting AOT compilation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +import numpy as np + +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.compiler.xla.python import pywrap_xla as c_api + +_UNARY_OPS = [ + 'Not', + 'Abs', + 'Exp', + 'Floor', + 'Ceil', + 'Log', + 'Sign', + 'Cos', + 'Sin', + 'Tanh', + 'SqrtF32', + 'SquareF32', + 'IsFinite', + 'ReciprocalF32', + 'Neg', + 'Sort', +] + +_BINARY_OPS = [ + 'Eq', + 'Ne', + 'Ge', + 'Gt', + 'Lt', + 'Le', + 'Add', + 'Sub', + 'Mul', + 'Div', + 'Rem', + 'Max', + 'Min', + 'And', + 'Or', + 'Pow', +] + +# Most functions are snake_case for consistency with other modules, +# whereas method names of ComputationBuilder and LocalComputation are +# CamelCase for consistency with XLA. +# pylint: disable=invalid-name + +XLA_ELEMENT_TYPE_TO_DTYPE = { + xla_data_pb2.F32: np.dtype(np.float32), + xla_data_pb2.F64: np.dtype(np.float64), + xla_data_pb2.S32: np.dtype(np.int32), + xla_data_pb2.S64: np.dtype(np.int64), + xla_data_pb2.PRED: np.dtype(np.bool), + xla_data_pb2.TUPLE: np.dtype(np.object), +} + +DTYPE_TO_XLA_ELEMENT_TYPE = { + str(v): k + for k, v in XLA_ELEMENT_TYPE_TO_DTYPE.items() +} + + +class Shape(object): + """XLA shape. + + Represents an XLA shape by a corresponding Python/Numpy type and a + list of dimensions, which are themselves Shapes in case this one + represents an XLA tuple. 
+ """ + + def __init__(self, np_dtype, dimensions): + self.np_dtype = np_dtype + self._dimensions = dimensions + + def element_type(self): + return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.np_dtype)] + + def is_tuple(self): + return self.element_type() == xla_data_pb2.TUPLE + + def dimensions(self): + if self.is_tuple(): + raise ValueError('Tuple shape has no dimensions') + return self._dimensions + + def tuple_shapes(self): + if not self.is_tuple(): + raise ValueError('Shape is not a tuple shape') + return self._dimensions + + @staticmethod + def from_numpy(npval): + + def convert(npval): + if isinstance(npval, tuple): + return Shape(np.dtype('O'), tuple(convert(elt) for elt in npval)) + else: + return Shape(npval.dtype, np.shape(npval)) + + return convert(require_numpy_array_layout(npval)) + + +def _wrap_shape(shape_info): + dtype, dims = shape_info + element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)] + if element_type == xla_data_pb2.TUPLE: + dims = [_wrap_shape(subshape_info) for subshape_info in dims] + return Shape(dtype, dims) + + +def _unwrap_shape(shape): + if shape.is_tuple(): + components = tuple( + _unwrap_shape(subshape) for subshape in shape.tuple_shapes()) + else: + components = shape.dimensions() + return (shape.np_dtype, components) + + +def _unwrap_shapes(shapes): + return [_unwrap_shape(shape) for shape in shapes] + + +def _wrap_data_handle(handle): + cdh = xla_data_pb2.ComputationDataHandle() + cdh.handle = handle + return cdh + + +def _unwrap_data_handle(handle_proto): + return handle_proto.handle + + +def _unwrap_data_handles(handle_protos): + return [_unwrap_data_handle(cdh) for cdh in handle_protos] + + +def require_numpy_array_layout(value): + if isinstance(value, tuple): + return tuple(require_numpy_array_layout(x) for x in value) + else: + return np.require(value, requirements=['C', 'A']) + + +class LocalComputation(object): + """Python wrapper for a local XLA Computation. + + A LocalComputation can be executed if it is compiled. 
Otherwise, it + can still be used as a Computation where required by the + ComputationBuilder methods. + """ + + def __init__(self, c_local_computation, is_compiled): + self.c_local_computation = c_local_computation + self.is_compiled = is_compiled + + # Ensure a reference to C-based destructor for use in __del__. + if is_compiled: + self._delete = c_api.DeleteCompiledLocalComputation + else: + self._delete = c_api.DeleteLocalComputation + + def Compile(self, argument_shapes=()): + if self.is_compiled: + raise ValueError('Attempt to compile a compiled local XLA computation.') + return LocalComputation( + self.c_local_computation.Compile(_unwrap_shapes(argument_shapes)), + is_compiled=True) + + def CompileWithExampleArguments(self, arguments=()): + return self.Compile( + argument_shapes=[Shape.from_numpy(arg) for arg in arguments]) + + def Execute(self, arguments=()): + if not self.is_compiled: + raise ValueError('Cannot execute an uncompiled local XLA computation.') + arguments = tuple(map(require_numpy_array_layout, arguments)) + return self.c_local_computation.Execute(arguments) + + def __del__(self): + self._delete(self.c_local_computation) + + +class ComputationBuilder(object): + """XLA computation builder. + + Enqueues XLA ops in sequence and in order to build a + LocalComputation, which in turn can be compiled into a + CompiledLocalComputation, which in turn can be locally executed. + """ + + # The methods of this class map 1-to-1 onto the XLA C++ + # computation builder API. Therefore, there's no need to laboriously list + # arguments and return values for every method, especially where it's obvious. 
+ # + # pylint: disable=g-doc-return-or-yield + # pylint: disable=g-doc-args + + def __init__(self, name): + self._client = c_api.LocalComputationBuilder(name.encode('utf8')) + self._parameter_numbering = itertools.count() + + def Build(self): + return LocalComputation(self._client.Build(), is_compiled=False) + + def Constant(self, value): + """Enqueues a constant op onto the computation. + + Args: + value: value for the constant, as a np.array with an explicit dtype set + to one of the supported types. + + Returns: + A ComputationDataHandle message. + """ + value = require_numpy_array_layout(value) + return _wrap_data_handle(self._client.ConstantLiteral(value)) + + def ConstantF32Scalar(self, value): + """Convenience method to enqueue a scalar F32 constant op. + + Args: + value: a floating-point number. + + Returns: + A ComputationDataHandle message. + """ + return self.Constant(np.array(value, dtype=np.float32)) + + def ConstantF64Scalar(self, value): + """Convenience method to enqueue a scalar F64 constant op. + + Args: + value: a floating-point number. + + Returns: + A ComputationDataHandle message. + """ + return self.Constant(np.array(value, dtype=np.float64)) + + def ConstantS32Scalar(self, value): + """Convenience method to enqueue a scalar S32 constant op. + + Args: + value: an integer. + + Returns: + A ComputationDataHandle message. + """ + return self.Constant(np.array(value, dtype=np.int32)) + + def ConstantS64Scalar(self, value): + """Convenience method to enqueue a scalar S64 constant op. + + Args: + value: an integer. + + Returns: + A ComputationDataHandle message. + """ + return self.Constant(np.array(value, dtype=np.int64)) + + def ConstantPredScalar(self, value): + """Convenience method to enqueue a scalar PRED constant op. + + Args: + value: a boolean value. + + Returns: + A ComputationDataHandle message. 
+ """ + return self.Constant(np.array(value, dtype=np.bool)) + + def ParameterWithShape(self, shape, name=None, parameter_num=None): + """Enqueues a Parameter op onto the computation, given a shape. + + Args: + shape: the parameter's shape as a Shape object. + name: optional string name for the parameter. + parameter_num: parameter number in the computation function. If None, + the next linear parameter number is used. The default value capability + can be used for auto-numbering. If you're using auto-numbering for some + parameters, use it for *all* parameters to avoid clashes. + + Returns: + A ComputationDataHandle message. + """ + if name is None: + name = '' + if parameter_num is None: + parameter_num = next(self._parameter_numbering) + + return _wrap_data_handle( + self._client.Parameter( + parameter_num, _unwrap_shape(shape), name.encode('utf8'))) + + def ParameterFromNumpy(self, value, name=None, parameter_num=None): + """Enqueues a Parameter op onto the computation. + + Args: + value: a Numpy array, or a nested tuple thereof, from which the + shape is inferred. + name: as in ParameterWithShape. + parameter_num: as in ParameterWithShape. + + Returns: + A ComputationDataHandle message. + """ + return self.ParameterWithShape( + Shape.from_numpy(value), name=name, parameter_num=parameter_num) + + def Broadcast(self, operand, sizes): + """Enqueues a broadcast operation onto the computation. + + Args: + operand: the operand ComputationDataHandle to broadcast. + sizes: an iterable of broadcast sizes. + + Returns: + A ComputationDataHandle representing the added broadcast op. + """ + return _wrap_data_handle( + self._client.Broadcast(_unwrap_data_handle(operand), sizes)) + + def Concatenate(self, operands, dimension): + """Enqueues a concatenate operation onto the computation. + + Args: + operands: the operands to concatenate. + dimension: the dimension in which to perform the concatenation. 
+ + Returns: + A ComputationDataHandle representing the added concatenate op. + """ + return _wrap_data_handle( + self._client.ConcatInDim(_unwrap_data_handles(operands), dimension)) + + def ConvertElementType(self, operand, new_element_type): + """Enqueues an element type conversion operation onto the computation. + + Args: + operand: the operand to convert. + new_element_type: the target primitive type. + + Returns: + A ComputationDataHandle representing the added conversion op. + """ + return _wrap_data_handle( + self._client.ConvertElementType( + _unwrap_data_handle(operand), new_element_type)) + + def GetShape(self, operand): + return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand))) + + def GetComputationStats(self): + raise NotImplementedError() + + def Reshape(self, operand, dimensions, new_sizes): + """Reshape op.""" + return _wrap_data_handle( + self._client.Reshape( + _unwrap_data_handle(operand), dimensions, new_sizes)) + + def Trans(self, operand): + """Specialized matrix transpose op.""" + return _wrap_data_handle( + self._client.Transpose(_unwrap_data_handle(operand), [1, 0])) + + def Transpose(self, operand, permutation): + """Transpose op.""" + return _wrap_data_handle( + self._client.Transpose(_unwrap_data_handle(operand), permutation)) + + def Select(self, pred, on_true, on_false): + """Element-wise selection op. + + Constructs an output array from elements of two input arrays, based on the + values of a predicate array. + """ + return _wrap_data_handle( + self._client.Select( + _unwrap_data_handle(pred), + _unwrap_data_handle(on_true), + _unwrap_data_handle(on_false))) + + def Slice(self, operand, start_indices, limit_indices, strides=None): + """Enqueues a slice operation onto the computation. + + Args: + operand: ComputationDataHandle for the N dimensional array to be sliced. + start_indices: iterable of N integers containing the starting indices of + the slice for each dimension. 
+ limit_indices: iterable of N integers containing the ending indices + (exclusive) of the slice for each dimension. + strides: optional iterable of N integers containing the stride sizes for + each dimension. + + Returns: + A ComputationDataHandle representing the added Slice op. + """ + if strides is None: + start_indices = list(start_indices) + strides = [1] * len(start_indices) + return _wrap_data_handle( + self._client.Slice( + _unwrap_data_handle(operand), + start_indices, + limit_indices, + strides)) + + def DynamicSlice(self, operand, start_indices, slice_sizes): + """Enqueues a slice op with dynamic start indices onto the computation. + + Args: + operand: ComputationDataHandle for the N dimensional array to be sliced. + start_indices: ComputationDataHandle for the 1D array of N integers + containing the starting indices of the slice. + slice_sizes: iterable of N integers containing the slice sizes in each + dimension. + + Returns: + A ComputationDataHandle representing the added DynamicSlice op. + """ + return _wrap_data_handle( + self._client.DynamicSlice( + _unwrap_data_handle(operand), + _unwrap_data_handle(start_indices), + slice_sizes)) + + def DynamicUpdateSlice(self, operand, update, start_indices): + """Enqueues a dynamic update slice operation onto the computation. + + Args: + operand: ComputationDataHandle for the N dimensional array to be updated. + update: N dimensional array comprising the slice update. + start_indices: Rank-1 array of N integers comprising the starting indices + of the slice along each dimension. + Returns: + A ComputationDataHandle representing the added DynamicUpdateSlice op. + """ + return _wrap_data_handle( + self._client.DynamicUpdateSlice( + _unwrap_data_handle(operand), + _unwrap_data_handle(update), + _unwrap_data_handle(start_indices))) + + def Tuple(self, *ops): + """Enqueues a tuple operation onto the computation. + + Args: + ops: a sequence of tuple operands (each a ComputationDataHandle). 
+ + Returns: + A ComputationDataHandle representing the added Tuple op. + """ + return _wrap_data_handle(self._client.Tuple(_unwrap_data_handles(ops))) + + def GetTupleElement(self, tup, index): + """Enqueues a 'get tuple element' operation onto the computation. + + Args: + tup: the tuple operand (a ComputationDataHandle). + index: numeric index to select from the tuple. + + Returns: + A ComputationDataHandle representing the added GetTupleElement op. + """ + return _wrap_data_handle( + self._client.GetTupleElement(_unwrap_data_handle(tup), index)) + + def Call(self, computation_to_apply, operands): + """Enqueues a call operation onto the computation. + + Args: + computation_to_apply: a Computation object. + operands: an iterable of ComputationDataHandle. The number and types of + operands must match the arity of computation_to_apply. + + Returns: + A ComputationDataHandle representing the added call op. + """ + return _wrap_data_handle( + self._client.Call(computation_to_apply.c_local_computation, + _unwrap_data_handles(operands))) + + def Map(self, operands, computation_to_apply, dimensions, static_operands=()): + """Enqueues a map operation onto the computation. + + Args: + operands: an iterable of ComputationDataHandle. + computation_to_apply: a Computation object. + dimensions: dimensions over which to apply map the function. + static_operands: auxiliary arguments passed to the applied computation. + + Returns: + A ComputationDataHandle representing the added Map op. + """ + return _wrap_data_handle( + self._client.Map( + _unwrap_data_handles(operands), + computation_to_apply.c_local_computation, + dimensions, + _unwrap_data_handles(static_operands))) + + def Reduce(self, operand, init_value, computation_to_apply, dimensions): + """Enqueues a reduction operation onto the computation. + + Args: + operand: reduction operand (ComputationDataHandle). + init_value: reduction initial value (ComputationDataHandle). 
+ computation_to_apply: a Computation object - binary reduction function. + dimensions: sequence of dimensions (integers) to reduce on. + + Returns: + A ComputationDataHandle representing the added Reduce op. + """ + return _wrap_data_handle( + self._client.Reduce( + _unwrap_data_handle(operand), + _unwrap_data_handle(init_value), + computation_to_apply.c_local_computation, + dimensions)) + + def While(self, cond, body, init): + """Enqueues a While operation onto the computation. + + Args: + cond: a Computation for the loop condition, which has type T -> PRED + body: a Computation for the loop body, which has type T -> T + init: an ComputationDataHandle for the initial parameter, which has type T + + Returns: a ComputationDataHandle representing the While operation. + """ + return _wrap_data_handle( + self._client.While(cond.c_local_computation, + body.c_local_computation, + _unwrap_data_handle(init))) + + def Dot(self, lhs, rhs): + """Matrix multiplication between lhs and rhs.""" + return _wrap_data_handle( + self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs))) + + +def _forward_methods_to_local_builder(): + """Forward remaining ComputationBuilder methods to the C API. + + Set up methods, corresponding to unary and binary XLA operations, + whose calls are forwarded in a boilerplate manner to the underlying + LocalComputationBuilder C-extension API. 
+ """ + + def forward_to_local_builder_with_handles(target_method, is_binop=False): + """Generate a forwarding method that wraps/unwraps data handles.""" + + def forward(self, *args, **kwargs): + unwrapped_args = [_unwrap_data_handle(arg) for arg in args] + + if is_binop and len(unwrapped_args) < 3: + unwrapped_args.append(kwargs.get('broadcast_dimensions', ())) + + return _wrap_data_handle( + target_method( + self._client, # pylint: disable=protected-access + *unwrapped_args)) + + return forward + + for method_name in _UNARY_OPS: + forward = forward_to_local_builder_with_handles( + getattr(c_api.LocalComputationBuilder, method_name)) + forward.__name__ = method_name + setattr(ComputationBuilder, method_name, forward) + + for method_name in _BINARY_OPS: + forward = forward_to_local_builder_with_handles( + getattr(c_api.LocalComputationBuilder, method_name), is_binop=True) + forward.__name__ = method_name + setattr(ComputationBuilder, method_name, forward) + + +_forward_methods_to_local_builder() diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py new file mode 100644 index 0000000000..878cd83edc --- /dev/null +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -0,0 +1,898 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the Python extension-based XLA client.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +import numpy as np + +from tensorflow.compiler.xla.python import xla_client +import unittest + + +class LocalComputationTest(unittest.TestCase): + """Base class for running an XLA Computation through the local client.""" + + def _NewComputation(self, name=None): + if name is None: + name = self.id() + return xla_client.ComputationBuilder(name) + + def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected): + assert expected is not None + compiled_c = c.Build().CompileWithExampleArguments(arguments) + result = compiled_c.Execute(arguments) + # Numpy's comparison methods are a bit too lenient by treating inputs as + # "array-like", meaning that scalar 4 will be happily compared equal to + # [[4]]. We'd like to be more strict so assert shapes as well. 
+ self.assertEqual(np.asanyarray(result).shape, np.asanyarray(expected).shape) + assert_func(result, expected) + + def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): + self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected) + + def _ExecuteAndCompareClose(self, c, arguments=(), expected=None): + self._ExecuteAndAssertWith(np.testing.assert_allclose, c, arguments, + expected) + + +def NumpyArrayF32(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.float32 dtype.""" + return np.array(*args, dtype=np.float32, **kwargs) + + +def NumpyArrayF64(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.float64 dtype.""" + return np.array(*args, dtype=np.float64, **kwargs) + + +def NumpyArrayS32(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.int32 dtype.""" + return np.array(*args, dtype=np.int32, **kwargs) + + +def NumpyArrayS64(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.int64 dtype.""" + return np.array(*args, dtype=np.int64, **kwargs) + + +def NumpyArrayBool(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.bool dtype.""" + return np.array(*args, dtype=np.bool, **kwargs) + + +class ComputationsWithConstantsTest(LocalComputationTest): + """Tests focusing on Constant ops.""" + + def testConstantScalarSumF32(self): + c = self._NewComputation() + c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) + self._ExecuteAndCompareClose(c, expected=4.25) + + def testConstantScalarSumF64(self): + c = self._NewComputation() + c.Add(c.ConstantF64Scalar(1.11), c.ConstantF64Scalar(3.14)) + self._ExecuteAndCompareClose(c, expected=4.25) + + def testConstantScalarSumS32(self): + c = self._NewComputation() + c.Add(c.ConstantS32Scalar(1), c.ConstantS32Scalar(2)) + self._ExecuteAndCompareClose(c, expected=3) + + def testConstantScalarSumS64(self): + c = self._NewComputation() + c.Add(c.ConstantS64Scalar(1), 
c.ConstantS64Scalar(2)) + self._ExecuteAndCompareClose(c, expected=3) + + def testConstantVectorMulF32(self): + c = self._NewComputation() + c.Mul( + c.Constant(NumpyArrayF32([2.5, 3.3, -1.2, 0.7])), + c.Constant(NumpyArrayF32([-1.2, 2, -2, -3]))) + self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1]) + + def testConstantVectorMulF64(self): + c = self._NewComputation() + c.Mul( + c.Constant(NumpyArrayF64([2.5, 3.3, -1.2, 0.7])), + c.Constant(NumpyArrayF64([-1.2, 2, -2, -3]))) + self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1]) + + def testConstantVectorScalarDivF32(self): + c = self._NewComputation() + c.Div( + c.Constant(NumpyArrayF32([1.5, 2.5, 3.0, -10.8])), + c.ConstantF32Scalar(2.0)) + self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4]) + + def testConstantVectorScalarDivF64(self): + c = self._NewComputation() + c.Div( + c.Constant(NumpyArrayF64([1.5, 2.5, 3.0, -10.8])), + c.ConstantF64Scalar(2.0)) + self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4]) + + def testConstantVectorScalarPowF32(self): + c = self._NewComputation() + c.Pow(c.Constant(NumpyArrayF32([1.5, 2.5, 3.0])), c.ConstantF32Scalar(2.)) + self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) + + def testConstantVectorScalarPowF64(self): + c = self._NewComputation() + c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.)) + self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) + + def testBooleanAnd(self): + c = self._NewComputation() + c.And( + c.Constant(NumpyArrayBool([True, False, True, False])), + c.Constant(NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[True, False, False, False]) + + def testBooleanOr(self): + c = self._NewComputation() + c.Or( + c.Constant(NumpyArrayBool([True, False, True, False])), + c.Constant(NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[True, True, True, False]) + + def testSum2DF32(self): + c = 
self._NewComputation() + c.Add( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])), + c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) + self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + + def testSum2DF64(self): + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6]])), + c.Constant(NumpyArrayF64([[1, -1, 1], [-1, 1, -1]]))) + self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + + def testSum2DWith1DBroadcastDim0F32(self): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 0 to match the former's shape. + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF32([10, 20, 30])), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareClose( + c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]]) + + def testSum2DWith1DBroadcastDim0F64(self): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 0 to match the former's shape. + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF64([10, 20, 30])), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareClose( + c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]]) + + def testSum2DWith1DBroadcastDim1F32(self): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 1 to match the former's shape. + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF32([10, 20, 30])), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareClose( + c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]]) + + def testSum2DWith1DBroadcastDim1F64(self): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 1 to match the former's shape. 
+ c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF64([10, 20, 30])), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareClose( + c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]]) + + def testConstantAxpyF32(self): + c = self._NewComputation() + c.Add( + c.Mul( + c.ConstantF32Scalar(2), + c.Constant(NumpyArrayF32([2.2, 3.3, 4.4, 5.5]))), + c.Constant(NumpyArrayF32([100, -100, 200, -200]))) + self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) + + def testConstantAxpyF64(self): + c = self._NewComputation() + c.Add( + c.Mul( + c.ConstantF64Scalar(2), + c.Constant(NumpyArrayF64([2.2, 3.3, 4.4, 5.5]))), + c.Constant(NumpyArrayF64([100, -100, 200, -200]))) + self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) + + +class ParametersTest(LocalComputationTest): + """Tests focusing on Parameter ops and argument-passing.""" + + def setUp(self): + self.f32_scalar_2 = NumpyArrayF32(2.0) + self.f32_4vector = NumpyArrayF32([-2.3, 3.3, -4.3, 5.3]) + self.f64_scalar_2 = NumpyArrayF64(2.0) + self.f64_4vector = NumpyArrayF64([-2.3, 3.3, -4.3, 5.3]) + self.s32_scalar_3 = NumpyArrayS32(3) + self.s32_4vector = NumpyArrayS32([10, 15, -2, 7]) + self.s64_scalar_3 = NumpyArrayS64(3) + self.s64_4vector = NumpyArrayS64([10, 15, -2, 7]) + + def testScalarTimesVectorAutonumberF32(self): + c = self._NewComputation() + p0 = c.ParameterFromNumpy(self.f32_scalar_2) + p1 = c.ParameterFromNumpy(self.f32_4vector) + c.Mul(p0, p1) + self._ExecuteAndCompareClose( + c, + arguments=[self.f32_scalar_2, self.f32_4vector], + expected=[-4.6, 6.6, -8.6, 10.6]) + + def testScalarTimesVectorAutonumberF64(self): + c = self._NewComputation() + p0 = c.ParameterFromNumpy(self.f64_scalar_2) + p1 = c.ParameterFromNumpy(self.f64_4vector) + c.Mul(p0, p1) + self._ExecuteAndCompareClose( + c, + arguments=[self.f64_scalar_2, self.f64_4vector], + expected=[-4.6, 6.6, -8.6, 10.6]) + + def 
testScalarTimesVectorS32(self): + c = self._NewComputation() + p0 = c.ParameterFromNumpy(self.s32_scalar_3) + p1 = c.ParameterFromNumpy(self.s32_4vector) + c.Mul(p0, p1) + self._ExecuteAndCompareExact( + c, + arguments=[self.s32_scalar_3, self.s32_4vector], + expected=[30, 45, -6, 21]) + + def testScalarTimesVectorS64(self): + c = self._NewComputation() + p0 = c.ParameterFromNumpy(self.s64_scalar_3) + p1 = c.ParameterFromNumpy(self.s64_4vector) + c.Mul(p0, p1) + self._ExecuteAndCompareExact( + c, + arguments=[self.s64_scalar_3, self.s64_4vector], + expected=[30, 45, -6, 21]) + + def testScalarMinusVectorExplicitNumberingF32(self): + # Use explicit numbering and pass parameter_num first. Sub is used since + # it's not commutative and can help catch parameter reversal within the + # computation. + c = self._NewComputation() + p1 = c.ParameterFromNumpy(self.f32_4vector, parameter_num=1) + p0 = c.ParameterFromNumpy(self.f32_scalar_2, parameter_num=0) + c.Sub(p1, p0) + self._ExecuteAndCompareClose( + c, + arguments=[self.f32_scalar_2, self.f32_4vector], + expected=[-4.3, 1.3, -6.3, 3.3]) + + def testScalarMinusVectorExplicitNumberingF64(self): + # Use explicit numbering and pass parameter_num first. Sub is used since + # it's not commutative and can help catch parameter reversal within the + # computation. + c = self._NewComputation() + p1 = c.ParameterFromNumpy(self.f64_4vector, parameter_num=1) + p0 = c.ParameterFromNumpy(self.f64_scalar_2, parameter_num=0) + c.Sub(p1, p0) + self._ExecuteAndCompareClose( + c, + arguments=[self.f64_scalar_2, self.f64_4vector], + expected=[-4.3, 1.3, -6.3, 3.3]) + + +class SingleOpTest(LocalComputationTest): + """Tests for single ops. + + The goal here is smoke testing - to exercise the most basic functionality of + single XLA ops. As minimal as possible number of additional ops are added + around the op being tested. 
+ """ + + def testConcatenateF32(self): + c = self._NewComputation() + c.Concatenate( + (c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])), + c.Constant(NumpyArrayF32([4.0, 5.0, 6.0]))), + dimension=0) + self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + + def testConcatenateF64(self): + c = self._NewComputation() + c.Concatenate( + (c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])), + c.Constant(NumpyArrayF64([4.0, 5.0, 6.0]))), + dimension=0) + self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + + def testConvertElementType(self): + xla_types = { + np.bool: xla_client.xla_data_pb2.PRED, + np.int32: xla_client.xla_data_pb2.S32, + np.int64: xla_client.xla_data_pb2.S64, + np.float32: xla_client.xla_data_pb2.F32, + np.float64: xla_client.xla_data_pb2.F64, + } + + def _ConvertAndTest(template, src_dtype, dst_dtype): + c = self._NewComputation() + x = c.Constant(np.array(template, dtype=src_dtype)) + c.ConvertElementType(x, xla_types[dst_dtype]) + + result = c.Build().Compile().Execute() + expected = np.array(template, dtype=dst_dtype) + + self.assertEqual(result.shape, expected.shape) + self.assertEqual(result.dtype, expected.dtype) + np.testing.assert_equal(result, expected) + + x = [0, 1, 0, 0, 1] + for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): + _ConvertAndTest(x, src_dtype, dst_dtype) + + def testDotMatrixVectorF32(self): + c = self._NewComputation() + lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) + rhs = NumpyArrayF32([[10.0], [20.0]]) + c.Dot(c.Constant(lhs), c.Constant(rhs)) + self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + + def testDotMatrixVectorF64(self): + c = self._NewComputation() + lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) + rhs = NumpyArrayF64([[10.0], [20.0]]) + c.Dot(c.Constant(lhs), c.Constant(rhs)) + self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + + def testDotMatrixMatrixF32(self): + c = self._NewComputation() + lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) + rhs 
= NumpyArrayF32([[10.0, 20.0], [100.0, 200.0]]) + c.Dot(c.Constant(lhs), c.Constant(rhs)) + self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + + def testDotMatrixMatrixF64(self): + c = self._NewComputation() + lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) + rhs = NumpyArrayF64([[10.0, 20.0], [100.0, 200.0]]) + c.Dot(c.Constant(lhs), c.Constant(rhs)) + self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + + def testBooleanNot(self): + c = self._NewComputation() + arr = NumpyArrayBool([True, False, True]) + c.Not(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=~arr) + + def testExp(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Exp(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.exp(arr)) + + def testLog(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Log(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.log(arr)) + + def testNeg(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Neg(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=-arr) + + def testFloor(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Floor(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.floor(arr)) + + def testCeil(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Ceil(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.ceil(arr)) + + def testAbs(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.]) + c.Abs(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.abs(arr)) + + def testTanh(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Tanh(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.tanh(arr)) + + def testTrans(self): + + def _TransposeAndTest(array): + c = self._NewComputation() + c.Trans(c.Constant(array)) + self._ExecuteAndCompareClose(c, expected=array.T) + + # Test square and non-square 
matrices in both default (C) and F orders. + for array_fun in [NumpyArrayF32, NumpyArrayF64]: + _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]])) + _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]], order="F")) + _TransposeAndTest(array_fun([[1, 2], [4, 5]])) + _TransposeAndTest(array_fun([[1, 2], [4, 5]], order="F")) + + def testTranspose(self): + + def _TransposeAndTest(array, permutation): + c = self._NewComputation() + c.Transpose(c.Constant(array), permutation) + expected = np.transpose(array, permutation) + self._ExecuteAndCompareClose(c, expected=expected) + + _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1]) + _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0]) + _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1]) + _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0]) + + arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32) + for permutation in itertools.permutations(range(arr.ndim)): + _TransposeAndTest(arr, permutation) + _TransposeAndTest(np.asfortranarray(arr), permutation) + + def testEq(self): + c = self._NewComputation() + c.Eq( + c.Constant(NumpyArrayS32([1, 2, 3, 4])), + c.Constant(NumpyArrayS32([4, 2, 3, 1]))) + self._ExecuteAndCompareExact(c, expected=[False, True, True, False]) + + def testNe(self): + c = self._NewComputation() + c.Ne( + c.Constant(NumpyArrayS32([1, 2, 3, 4])), + c.Constant(NumpyArrayS32([4, 2, 3, 1]))) + self._ExecuteAndCompareExact(c, expected=[True, False, False, True]) + + c.Ne( + c.Constant(NumpyArrayF32([-2.0, 0.0, + float("nan"), + float("nan")])), + c.Constant(NumpyArrayF32([2.0, -0.0, 1.0, float("nan")]))) + self._ExecuteAndAssertWith( + np.testing.assert_allclose, c, (), expected=[True, False, True, True]) + + def testGt(self): + c = self._NewComputation() + c.Gt( + c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), + c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact(c, expected=[False, True, True, False, False]) + + def testGe(self): + c = 
self._NewComputation() + c.Ge( + c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), + c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact(c, expected=[True, True, True, False, False]) + + def testLt(self): + c = self._NewComputation() + c.Lt( + c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), + c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact(c, expected=[False, False, False, True, True]) + + def testLe(self): + c = self._NewComputation() + c.Le( + c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), + c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact(c, expected=[True, False, False, True, True]) + + def testMax(self): + c = self._NewComputation() + c.Max( + c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), + c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) + self._ExecuteAndCompareExact(c, expected=[1.0, 2.0, 3.0, 7.0, 12.0]) + + def testMaxExplicitBroadcastDim0(self): + c = self._NewComputation() + c.Max( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF32([3, 4, 5])), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareExact(c, expected=[[3, 3, 3], [4, 5, 6], [7, 8, 9]]) + + def testMaxExplicitBroadcastDim1(self): + c = self._NewComputation() + c.Max( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF32([3, 4, 5])), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareExact(c, expected=[[3, 4, 5], [4, 5, 6], [7, 8, 9]]) + + def testMin(self): + c = self._NewComputation() + c.Min( + c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), + c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) + self._ExecuteAndCompareExact(c, expected=[1.0, 0.0, 2.0, 4.0, 9.0]) + + def testReshape(self): + c = self._NewComputation() + c.Reshape( + c.Constant(NumpyArrayS32([[1, 2], [3, 4], [5, 6]])), + dimensions=[0, 1], + new_sizes=[2, 3]) + self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 5, 6]]) + + def testSelect(self): + c = 
self._NewComputation() + c.Select( + c.Constant(NumpyArrayBool([True, False, False, True, False])), + c.Constant(NumpyArrayS32([1, 2, 3, 4, 5])), + c.Constant(NumpyArrayS32([-1, -2, -3, -4, -5]))) + self._ExecuteAndCompareExact(c, expected=[1, -2, -3, 4, -5]) + + def testSlice(self): + c = self._NewComputation() + c.Slice( + c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [1, 0], + [3, 2]) + self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) + + def testDynamicSlice(self): + c = self._NewComputation() + c.DynamicSlice( + c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayS32([1, 0])), [2, 2]) + self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) + + def testDynamicUpdateSlice(self): + c = self._NewComputation() + c.DynamicUpdateSlice( + c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayS32([[1, 2], [3, 4]])), + c.Constant(NumpyArrayS32([1, 1]))) + self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 1, 2], [7, 3, 4]]) + + def testTuple(self): + c = self._NewComputation() + c.Tuple( + c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), + c.Constant(NumpyArrayBool([True, False, False, True]))) + result = c.Build().Compile().Execute() + self.assertIsInstance(result, tuple) + np.testing.assert_equal(result[0], 42) + np.testing.assert_allclose(result[1], [1.0, 2.0]) + np.testing.assert_equal(result[2], [True, False, False, True]) + + def testGetTupleElement(self): + c = self._NewComputation() + c.GetTupleElement( + c.Tuple( + c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), + c.Constant(NumpyArrayBool([True, False, False, True]))), 1) + self._ExecuteAndCompareClose(c, expected=[1.0, 2.0]) + + def testBroadcast(self): + c = self._NewComputation() + c.Broadcast(c.Constant(NumpyArrayS32([10, 20, 30, 40])), sizes=(3,)) + self._ExecuteAndCompareExact( + c, expected=[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]) + + +class 
EmbeddedComputationsTest(LocalComputationTest): + """Tests for XLA graphs with embedded computations (such as maps).""" + + def _CreateConstantS32Computation(self): + """Computation (f32) -> s32 that returns a constant 1 for any input.""" + c = self._NewComputation("constant_s32_one") + # TODO(eliben): consider adding a nicer way to create new parameters without + # having to create dummy Numpy arrays or populating Shape messages. Perhaps + # we need our own (Python-client-own) way to represent Shapes conveniently. + c.ParameterFromNumpy(NumpyArrayF32(0)) + c.ConstantS32Scalar(1) + return c.Build() + + def _CreateConstantS64Computation(self): + """Computation (f64) -> s64 that returns a constant 1 for any input.""" + c = self._NewComputation("constant_s64_one") + # TODO(eliben): consider adding a nicer way to create new parameters without + # having to create dummy Numpy arrays or populating Shape messages. Perhaps + # we need our own (Python-client-own) way to represent Shapes conveniently. 
+ c.ParameterFromNumpy(NumpyArrayF64(0)) + c.ConstantS64Scalar(1) + return c.Build() + + def _CreateConstantF32Computation(self): + """Computation (f32) -> f32 that returns a constant 1.0 for any input.""" + c = self._NewComputation("constant_f32_one") + c.ParameterFromNumpy(NumpyArrayF32(0)) + c.ConstantF32Scalar(1.0) + return c.Build() + + def _CreateConstantF64Computation(self): + """Computation (f64) -> f64 that returns a constant 1.0 for any input.""" + c = self._NewComputation("constant_f64_one") + c.ParameterFromNumpy(NumpyArrayF64(0)) + c.ConstantF64Scalar(1.0) + return c.Build() + + def _CreateMulF32By2Computation(self): + """Computation (f32) -> f32 that multiplies its parameter by 2.""" + c = self._NewComputation("mul_f32_by2") + c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(2.0)) + return c.Build() + + def _CreateMulF64By2Computation(self): + """Computation (f64) -> f64 that multiplies its parameter by 2.""" + c = self._NewComputation("mul_f64_by2") + c.Mul(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(2.0)) + return c.Build() + + def _CreateBinaryAddF32Computation(self): + """Computation (f32, f32) -> f32 that adds its two parameters.""" + c = self._NewComputation("add_param0_by_param1") + c.Add( + c.ParameterFromNumpy(NumpyArrayF32(0)), + c.ParameterFromNumpy(NumpyArrayF32(0))) + return c.Build() + + def _CreateBinaryAddF64Computation(self): + """Computation (f64, f64) -> f64 that adds its two parameters.""" + c = self._NewComputation("add_param0_by_param1") + c.Add( + c.ParameterFromNumpy(NumpyArrayF64(0)), + c.ParameterFromNumpy(NumpyArrayF64(0))) + return c.Build() + + def _CreateBinaryDivF32Computation(self): + """Computation (f32, f32) -> f32 that divides its two parameters.""" + c = self._NewComputation("div_param0_by_param1") + c.Div( + c.ParameterFromNumpy(NumpyArrayF32(0)), + c.ParameterFromNumpy(NumpyArrayF32(0))) + return c.Build() + + def _CreateBinaryDivF64Computation(self): + """Computation (f64, f64) -> 
f64 that divides its two parameters.""" + c = self._NewComputation("div_param0_by_param1") + c.Div( + c.ParameterFromNumpy(NumpyArrayF64(0)), + c.ParameterFromNumpy(NumpyArrayF64(0))) + return c.Build() + + def _CreateTestF32Lt10Computation(self): + """Computation (f32) -> bool that tests if its parameter is less than 10.""" + c = self._NewComputation("test_f32_lt_10") + c.Lt(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(10.)) + return c.Build() + + def _CreateTestF64Lt10Computation(self): + """Computation (f64) -> bool that tests if its parameter is less than 10.""" + c = self._NewComputation("test_f64_lt_10") + c.Lt(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(10.)) + return c.Build() + + def _MakeSample3DArrayF32(self): + return NumpyArrayF32([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], + [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]]) + + def _MakeSample3DArrayF64(self): + return NumpyArrayF64([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], + [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]]) + + def testCallF32(self): + c = self._NewComputation() + c.Call( + self._CreateMulF32By2Computation(), + operands=(c.ConstantF32Scalar(5.0),)) + self._ExecuteAndCompareClose(c, expected=10.0) + + def testCallF64(self): + c = self._NewComputation() + c.Call( + self._CreateMulF64By2Computation(), + operands=(c.ConstantF64Scalar(5.0),)) + self._ExecuteAndCompareClose(c, expected=10.0) + + def testMapEachElementToS32Constant(self): + c = self._NewComputation() + c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], + self._CreateConstantS32Computation(), [0]) + self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1]) + + def testMapEachElementToS64Constant(self): + c = self._NewComputation() + c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], + self._CreateConstantS64Computation(), [0]) + self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1]) + + def testMapMulBy2F32(self): + c = self._NewComputation() + 
c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], + self._CreateMulF32By2Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0]) + + def testMapMulBy2F64(self): + c = self._NewComputation() + c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], + self._CreateMulF64By2Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0]) + + def testSimpleMapChainF32(self): + # Chains a map of constant-f32 with a map of mul-by-2 + c = self._NewComputation() + const_f32 = c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], + self._CreateConstantF32Computation(), [0]) + c.Map([const_f32], self._CreateMulF32By2Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0]) + + def testSimpleMapChainF64(self): + # Chains a map of constant-f64 with a map of mul-by-2 + c = self._NewComputation() + const_f64 = c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], + self._CreateConstantF64Computation(), [0]) + c.Map([const_f64], self._CreateMulF64By2Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0]) + + def testDivVectorsWithMapF32(self): + c = self._NewComputation() + c.Map((c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])), + c.Constant(NumpyArrayF32([5.0, 5.0, 4.0, 4.0]))), + self._CreateBinaryDivF32Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) + + def testDivVectorsWithMapF64(self): + c = self._NewComputation() + c.Map((c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])), + c.Constant(NumpyArrayF64([5.0, 5.0, 4.0, 4.0]))), + self._CreateBinaryDivF64Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) + + def testReduce1DtoScalarF32(self): + c = self._NewComputation() + c.Reduce( + operand=c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + dimensions=[0]) + self._ExecuteAndCompareClose(c, 
expected=10) + + def testReduce1DtoScalarF64(self): + c = self._NewComputation() + c.Reduce( + operand=c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + dimensions=[0]) + self._ExecuteAndCompareClose(c, expected=10) + + def testReduce2DTo1DDim0F32(self): + input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + dimensions=[0]) + self._ExecuteAndCompareClose(c, expected=[5, 7, 9]) + + def testReduce2DTo1DDim0F64(self): + input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + dimensions=[0]) + self._ExecuteAndCompareClose(c, expected=[5, 7, 9]) + + def testReduce2DTo1DDim1F32(self): + input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + dimensions=[1]) + self._ExecuteAndCompareClose(c, expected=[6, 15]) + + def testReduce2DTo1DDim1F64(self): + input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + dimensions=[1]) + self._ExecuteAndCompareClose(c, expected=[6, 15]) + + def testReduce3DAllPossibleWaysF32(self): + input_array = self._MakeSample3DArrayF32() + + def _ReduceAndTest(*dims): + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + 
computation_to_apply=self._CreateBinaryAddF32Computation(), + dimensions=dims) + self._ExecuteAndCompareClose( + c, expected=np.sum(input_array, axis=tuple(dims))) + + _ReduceAndTest(0) + _ReduceAndTest(0) + _ReduceAndTest(0, 1) + _ReduceAndTest(0, 2) + _ReduceAndTest(1, 2) + _ReduceAndTest(0, 1, 2) + + def testReduce3DAllPossibleWaysF64(self): + input_array = self._MakeSample3DArrayF64() + + def _ReduceAndTest(*dims): + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + dimensions=dims) + self._ExecuteAndCompareClose( + c, expected=np.sum(input_array, axis=tuple(dims))) + + _ReduceAndTest(0) + _ReduceAndTest(0) + _ReduceAndTest(0, 1) + _ReduceAndTest(0, 2) + _ReduceAndTest(1, 2) + _ReduceAndTest(0, 1, 2) + + def testWhileF32(self): + cond = self._CreateTestF32Lt10Computation() + body = self._CreateMulF32By2Computation() + c = self._NewComputation() + init = c.ConstantF32Scalar(1.) + c.While(cond, body, init) + self._ExecuteAndCompareClose(c, expected=16.) + + def testWhileF64(self): + cond = self._CreateTestF64Lt10Computation() + body = self._CreateMulF64By2Computation() + c = self._NewComputation() + init = c.ConstantF64Scalar(1.) + c.While(cond, body, init) + self._ExecuteAndCompareClose(c, expected=16.) + + +if __name__ == "__main__": + unittest.main() diff --git a/tensorflow/tf_exported_symbols.lds b/tensorflow/tf_exported_symbols.lds index bddb87f00c..3ff824e5e1 100644 --- a/tensorflow/tf_exported_symbols.lds +++ b/tensorflow/tf_exported_symbols.lds @@ -4,3 +4,4 @@ *TF_* *TFE_* *nsync_* +*pywrap_xla* diff --git a/tensorflow/tf_version_script.lds b/tensorflow/tf_version_script.lds index 11f66c5c8b..6b28943f01 100644 --- a/tensorflow/tf_version_script.lds +++ b/tensorflow/tf_version_script.lds @@ -5,6 +5,7 @@ tensorflow { *TF_*; *TFE_*; *nsync_*; + *pywrap_xla*; local: *; }; |