diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla')
37 files changed, 118 insertions, 100 deletions
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 40e32f2e75..fd31c26544 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -162,7 +162,7 @@ cc_library( ":sharding_util", ":tf2xla_util", "//tensorflow/compiler/tf2xla/lib:util", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -202,7 +202,7 @@ cc_library( ], visibility = [":friends"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:core_cpu_internal", @@ -285,6 +285,7 @@ tf_cc_test( deps = [ ":tf2xla", ":tf2xla_proto", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", @@ -327,7 +328,7 @@ tf_cc_test( "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:client_library", @@ -364,6 +365,7 @@ tf_cc_test( ], deps = [ ":common", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/core:framework", "//tensorflow/core:test", diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index a8eb7d942d..d1e8982877 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -114,6 +114,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/lib:while_loop", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -159,7 +160,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -175,7 +176,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -210,6 +211,7 @@ tf_kernel_library( ":index_ops_kernel_argmax_float_2d", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client/lib:arithmetic", diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index ee2c920453..ba3b1c9dab 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/bcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc index 2c76bcee25..81f42e504e 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index a020ebc729..22a45b2a11 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -78,14 +78,14 @@ class ArgMaxCustomCallOp : public XlaOpKernel { std::vector<xla::XlaOp> args; args.push_back(ctx->Input(0)); args.push_back(xla::ConstantLiteral( - &b, *xla::Literal::CreateR1<int64>(input_shape.dim_sizes()))); + &b, *xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes()))); if (input_shape.dims() > 1) { // Don't bother passing the output shape and dim for the 1d case, since // the shape is always a scalar and the dim is always 0. args.push_back(xla::ConstantLiteral( - &b, *xla::Literal::CreateR1<int64>(output_shape.dim_sizes()))); + &b, *xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes()))); args.push_back( - xla::ConstantLiteral(&b, *xla::Literal::CreateR0<int32>(dim))); + xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0<int32>(dim))); } xla::Shape xla_shape = diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index a81f5fddf6..12d9cb9bac 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 46fae59ad4..be7f2bce8c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 909783ecb3..ed1d1c6610 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index a4ba6c748a..f4b804e546 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index e0ca8dd8e2..354fec9be7 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index 037c422258..ec15b4cc7a 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 76924c6a01..27ab3e1bf5 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index bc3d0bf5df..25a5bcbe1d 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index ca74cf2450..242638f981 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 591e61b4c8..df91900570 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 2f650ce305..26326f18b8 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index 9962f1207d..1ddcb08c8e 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index bef6161e85..b62a6e778d 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index 0e5d58ecba..f951127bb9 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index febac82873..bb27b5d56f 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 340165bac6..9413a30a6c 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index dfa3c0595a..20fa03746c 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -40,7 +40,7 @@ cc_library( ":triangular_solve", ":util", ":while_loop", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -73,7 +73,7 @@ cc_library( deps = [ ":util", ":while_loop", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -92,7 +92,7 @@ cc_library( deps = [ ":batch_dot", ":util", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -111,7 +111,7 @@ xla_test( deps = [ ":triangular_solve", "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -133,6 +133,7 @@ cc_library( srcs = ["util.cc"], hdrs = ["util.h"], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -151,7 +152,7 @@ xla_test( ":batch_dot", ":util", "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index f9f3a8c8cf..3c4eec081b 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -84,7 +84,7 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, dimensions.push_back(y_shape.dimensions(y_outer_dim)); return xla::Broadcast( xla::ConstantLiteral(builder, - xla::Literal::Zero(x_shape.element_type())), + xla::LiteralUtil::Zero(x_shape.element_type())), dimensions); } diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index cc840de393..35b137aa2c 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 85e3d3ab85..6a5be1c2be 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -114,7 +114,7 @@ xla::StatusOr<xla::XlaOp> XlaScatter( auto buffer = loop_vars[2]; auto zero_index = xla::ConstantLiteral( - body_builder, xla::Literal::Zero(indices_shape.element_type())); + body_builder, xla::LiteralUtil::Zero(indices_shape.element_type())); // Slice the i-th index from the indices array. xla::XlaOp index; diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index 588afaac65..ce0f28db8f 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc index d5ffc1498e..f1bff6037b 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index fdc8bfca49..a6f5d346cb 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -19,6 +19,7 @@ limitations under the License. #include <vector> #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -28,6 +29,13 @@ limitations under the License. namespace tensorflow { +xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) { + return xla::Broadcast( + xla::ConstantLiteral(builder, + xla::LiteralUtil::Zero(shape.element_type())), + xla::AsInt64Slice(shape.dimensions())); +} + xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, double value) { switch (type) { @@ -56,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, xla::Literal literal; switch (type) { case xla::U8: - literal = std::move(*xla::Literal::CreateR0<uint8>(value)); + literal = std::move(*xla::LiteralUtil::CreateR0<uint8>(value)); break; case xla::U32: - literal = std::move(*xla::Literal::CreateR0<uint32>(value)); + literal = std::move(*xla::LiteralUtil::CreateR0<uint32>(value)); break; case xla::U64: - literal = std::move(*xla::Literal::CreateR0<uint64>(value)); + literal = std::move(*xla::LiteralUtil::CreateR0<uint64>(value)); break; case xla::S8: - literal = std::move(*xla::Literal::CreateR0<int8>(value)); + literal = std::move(*xla::LiteralUtil::CreateR0<int8>(value)); break; case xla::S32: - literal = std::move(*xla::Literal::CreateR0<int32>(value)); + literal = std::move(*xla::LiteralUtil::CreateR0<int32>(value)); break; case xla::S64: - literal = std::move(*xla::Literal::CreateR0<int64>(value)); + literal = std::move(*xla::LiteralUtil::CreateR0<int64>(value)); break; case xla::F32: - literal = std::move(*xla::Literal::CreateR0<float>(value)); + literal = std::move(*xla::LiteralUtil::CreateR0<float>(value)); break; case xla::F64: - literal = std::move(*xla::Literal::CreateR0<double>(value)); + literal = std::move(*xla::LiteralUtil::CreateR0<double>(value)); break; case xla::C64: - literal = std::move(*xla::Literal::CreateR0<complex64>(value)); + literal = std::move(*xla::LiteralUtil::CreateR0<complex64>(value)); break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; @@ -89,11 +97,11 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, LOG(FATAL) << "u16/s16 literals not yet implemented"; case xla::BF16: literal = std::move( - *xla::Literal::CreateR0<bfloat16>(static_cast<bfloat16>(value))); + *xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value))); break; case xla::F16: - literal = std::move( - *xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value))); + literal = std::move(*xla::LiteralUtil::CreateR0<xla::half>( + static_cast<xla::half>(value))); break; case xla::TUPLE: LOG(FATAL) << "tuple element type is not integral"; diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc index 7d0f2222a9..442fe92c34 100644 --- a/tensorflow/compiler/tf2xla/lib/util_test.cc +++ b/tensorflow/compiler/tf2xla/lib/util_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index 7cc88f34d2..574e70ddee 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -100,8 +100,9 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex( std::vector<xla::XlaOp> updated_values; updated_values.reserve(values.size()); updated_values.push_back(xla::Add( - iteration, xla::ConstantLiteral( - body_builder, xla::Literal::One(num_iterations_type)))); + iteration, + xla::ConstantLiteral(body_builder, + xla::LiteralUtil::One(num_iterations_type)))); values.remove_prefix(1); TF_ASSIGN_OR_RETURN(std::vector<xla::XlaOp> body_outputs, @@ -113,8 +114,8 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex( std::vector<xla::XlaOp> values; values.reserve(initial_values.size() + 1); - values.push_back( - xla::ConstantLiteral(builder, xla::Literal::Zero(num_iterations_type))); + values.push_back(xla::ConstantLiteral( + builder, xla::LiteralUtil::Zero(num_iterations_type))); values.insert(values.end(), initial_values.begin(), initial_values.end()); TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values, diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index b43405a1a4..2fb66913ad 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/common_runtime/dma_helper.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index ab7e861f33..0610a57029 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -18,7 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index f3d6787daa..a3404c2b3d 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/tensor_testutil.h" @@ -27,7 +28,7 @@ TEST(LiteralUtil, LiteralToHostTensor) { { std::vector<int64> int64_values = {1, 2, 3}; std::unique_ptr<xla::Literal> int64_values_literal = - xla::Literal::CreateR1(gtl::ArraySlice<int64>(int64_values)); + xla::LiteralUtil::CreateR1(gtl::ArraySlice<int64>(int64_values)); Tensor host_tensor; EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor) @@ -48,7 +49,7 @@ TEST(LiteralUtil, LiteralToHostTensor) { Tensor host_tensor; std::vector<int32> int32_values = {10, 11}; std::unique_ptr<xla::Literal> int32_values_literal = - xla::Literal::CreateR1(gtl::ArraySlice<int32>(int32_values)); + xla::LiteralUtil::CreateR1(gtl::ArraySlice<int32>(int32_values)); EXPECT_TRUE( LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor) .ok()); diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 84c133ffab..f0b30dcf4e 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -73,8 +74,8 @@ TEST(ConvertGraphDefToXla, Sum) { TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); // Set up arguments. - auto x_literal = xla::Literal::CreateR0<int32>(10); - auto y_literal = xla::Literal::CreateR0<int32>(32); + auto x_literal = xla::LiteralUtil::CreateR0<int32>(10); + auto y_literal = xla::LiteralUtil::CreateR0<int32>(32); auto x_global_or = client->TransferToServer(*x_literal); auto y_global_or = client->TransferToServer(*y_literal); TF_EXPECT_OK(x_global_or.status()); diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 07af8ef54b..6f76816a86 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -206,9 +206,9 @@ TEST_F(XlaCompilerTest, Simple) { // Tests that the generated computation works. std::unique_ptr<xla::Literal> param0_literal = - xla::Literal::CreateR1<int32>({7, 42}); + xla::LiteralUtil::CreateR1<int32>({7, 42}); std::unique_ptr<xla::Literal> param1_literal = - xla::Literal::CreateR1<int32>({-3, 101}); + xla::LiteralUtil::CreateR1<int32>({-3, 101}); std::unique_ptr<xla::GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> param1_data = @@ -222,9 +222,9 @@ TEST_F(XlaCompilerTest, Simple) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr<xla::Literal> expected0 = - xla::Literal::CreateR1<int32>({4, 143}); + xla::LiteralUtil::CreateR1<int32>({4, 143}); std::unique_ptr<xla::Literal> expected_literal = - xla::Literal::MakeTuple({expected0.get()}); + xla::LiteralUtil::MakeTuple({expected0.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -306,7 +306,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { // Tests that the generated computation works. std::unique_ptr<xla::Literal> param0_literal = - xla::Literal::CreateR1<int32>({7, 42}); + xla::LiteralUtil::CreateR1<int32>({7, 42}); std::unique_ptr<xla::GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -317,9 +317,9 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr<xla::Literal> expected0 = - xla::Literal::CreateR1<int32>({-7, -42}); + xla::LiteralUtil::CreateR1<int32>({-7, -42}); std::unique_ptr<xla::Literal> expected_literal = - xla::Literal::MakeTuple({expected0.get()}); + xla::LiteralUtil::MakeTuple({expected0.get()}); EXPECT_TRUE( xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -341,7 +341,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { // Tests that the generated computation works. std::unique_ptr<xla::Literal> param0_literal = - xla::Literal::CreateR1<int32>({7, 42}); + xla::LiteralUtil::CreateR1<int32>({7, 42}); std::unique_ptr<xla::GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -351,11 +351,12 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { std::unique_ptr<xla::Literal> actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> expected0 = xla::Literal::CreateR0<int32>(7); + std::unique_ptr<xla::Literal> expected0 = + xla::LiteralUtil::CreateR0<int32>(7); std::unique_ptr<xla::Literal> expected1 = - xla::Literal::CreateR1<int32>({-7, -42}); + xla::LiteralUtil::CreateR1<int32>({-7, -42}); std::unique_ptr<xla::Literal> expected = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal)); } } @@ -569,11 +570,11 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { // Tests that the generated computation works. std::unique_ptr<xla::Literal> input_base = - xla::Literal::CreateR1<int32>({7, 42}); + xla::LiteralUtil::CreateR1<int32>({7, 42}); std::unique_ptr<xla::Literal> input_grad2 = - xla::Literal::CreateR1<int32>({-3, 101}); + xla::LiteralUtil::CreateR1<int32>({-3, 101}); std::unique_ptr<xla::Literal> input = - xla::Literal::MakeTuple({input_base.get(), input_grad2.get()}); + xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()}); std::unique_ptr<xla::GlobalData> param0_data = client_->TransferToServer(*input).ConsumeValueOrDie(); @@ -583,17 +584,18 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { std::unique_ptr<xla::Literal> actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> output_read = xla::Literal::CreateR0<int32>(42); + std::unique_ptr<xla::Literal> output_read = + xla::LiteralUtil::CreateR0<int32>(42); std::unique_ptr<xla::Literal> output_base = - xla::Literal::CreateR1<int32>({7, 42}); + xla::LiteralUtil::CreateR1<int32>({7, 42}); std::unique_ptr<xla::Literal> output_grad1 = - xla::Literal::CreateR1<int32>({0, 1}); + xla::LiteralUtil::CreateR1<int32>({0, 1}); std::unique_ptr<xla::Literal> output_grad2 = - xla::Literal::CreateR1<int32>({-3, 101}); - std::unique_ptr<xla::Literal> output_resource = xla::Literal::MakeTuple( + xla::LiteralUtil::CreateR1<int32>({-3, 101}); + std::unique_ptr<xla::Literal> output_resource = xla::LiteralUtil::MakeTuple( {output_base.get(), output_grad1.get(), output_grad2.get()}); std::unique_ptr<xla::Literal> expected_literal = - xla::Literal::MakeTuple({output_read.get(), output_resource.get()}); + xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -796,9 +798,9 @@ TEST_F(XlaCompilerTest, Variables) { // Tests that the generated computation works. std::unique_ptr<xla::Literal> param0_literal = - xla::Literal::CreateR1<int32>({7, 42}); + xla::LiteralUtil::CreateR1<int32>({7, 42}); std::unique_ptr<xla::Literal> param1_literal = - xla::Literal::CreateR1<int32>({-3, 101}); + xla::LiteralUtil::CreateR1<int32>({-3, 101}); std::unique_ptr<xla::GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> param1_data = @@ -812,11 +814,11 @@ TEST_F(XlaCompilerTest, Variables) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr<xla::Literal> expected0 = - xla::Literal::CreateR1<int32>({5, 144}); + xla::LiteralUtil::CreateR1<int32>({5, 144}); std::unique_ptr<xla::Literal> expected1 = - xla::Literal::CreateR1<int32>({4, 143}); + xla::LiteralUtil::CreateR1<int32>({4, 143}); std::unique_ptr<xla::Literal> expected_literal = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -884,9 +886,9 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { // Tests that the generated computation works. std::unique_ptr<xla::Literal> param0_literal = - xla::Literal::CreateR2<int32>({{4, 55}, {1, -3}}); + xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}}); std::unique_ptr<xla::Literal> param1_literal = - xla::Literal::CreateR1<int32>({22, 11, 33, 404}); + xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404}); std::unique_ptr<xla::GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> param1_data = @@ -900,11 +902,11 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr<xla::Literal> expected0 = - xla::Literal::CreateR2<int32>({{27, 67}, {35, 402}}); + xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}}); std::unique_ptr<xla::Literal> expected1 = - xla::Literal::CreateR1<int32>({26, 66, 34, 401}); + xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401}); std::unique_ptr<xla::Literal> expected_literal = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } @@ -953,9 +955,9 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { // Tests that the generated computation works. std::unique_ptr<xla::Literal> param0_literal = - xla::Literal::CreateR1<int32>({4, 55, 1, -3}); + xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3}); std::unique_ptr<xla::Literal> param1_literal = - xla::Literal::CreateR1<int32>({22, 11, 33, 404}); + xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404}); std::unique_ptr<xla::GlobalData> param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr<xla::GlobalData> param1_data = @@ -969,11 +971,11 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr<xla::Literal> expected0 = - xla::Literal::CreateR1<int32>({27, 67, 35, 402}); + xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402}); std::unique_ptr<xla::Literal> expected1 = - xla::Literal::CreateR1<int32>({26, 66, 34, 401}); + xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401}); std::unique_ptr<xla::Literal> expected_literal = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index fd39a58ce6..0dea366476 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/lib/gtl/array_slice.h" diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index edbc5e95a8..4d1b3b1a13 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -94,13 +94,13 @@ xla::XlaOp ArgMinMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis, xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return xla::ConstantLiteral(b, xla::Literal::Zero(type)); + return xla::ConstantLiteral(b, xla::LiteralUtil::Zero(type)); } xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return xla::ConstantLiteral(b, xla::Literal::One(type)); + return xla::ConstantLiteral(b, xla::LiteralUtil::One(type)); } xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type, |