diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/lib/while_loop.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/lib/while_loop.cc | 52 |
1 files changed, 26 insertions, 26 deletions
diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index 495d9c6078..09ce594930 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -20,24 +20,24 @@ limitations under the License. namespace tensorflow { -xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop( +xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - gtl::ArraySlice<xla::ComputationDataHandle> initial_values, - StringPiece name, xla::ComputationBuilder* builder) { + gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name, + xla::XlaBuilder* builder) { int arity = initial_values.size(); std::vector<xla::Shape> var_shapes; var_shapes.reserve(arity); - for (const xla::ComputationDataHandle& input : initial_values) { + for (const xla::XlaOp& input : initial_values) { TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input)); - var_shapes.push_back(std::move(*shape)); + var_shapes.push_back(std::move(shape)); } xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes); // Unpacks a tuple into its component parts. - auto unpack_tuple = [](xla::ComputationDataHandle tuple, int arity, - xla::ComputationBuilder* builder) { - std::vector<xla::ComputationDataHandle> elements(arity); + auto unpack_tuple = [](xla::XlaOp tuple, int arity, + xla::XlaBuilder* builder) { + std::vector<xla::XlaOp> elements(arity); for (int i = 0; i < arity; ++i) { elements[i] = builder->GetTupleElement(tuple, i); } @@ -45,20 +45,20 @@ xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop( }; // Build the condition. - std::unique_ptr<xla::ComputationBuilder> cond_builder = + std::unique_ptr<xla::XlaBuilder> cond_builder = builder->CreateSubBuilder(strings::StrCat(name, "_condition")); { auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter"); - TF_ASSIGN_OR_RETURN( - auto result, + TF_RETURN_IF_ERROR( condition_function(unpack_tuple(parameter, arity, cond_builder.get()), - cond_builder.get())); + cond_builder.get()) + .status()); } TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build()); // Build the body. - std::unique_ptr<xla::ComputationBuilder> body_builder = + std::unique_ptr<xla::XlaBuilder> body_builder = builder->CreateSubBuilder(strings::StrCat(name, "_body")); { auto parameter = body_builder->Parameter(0, tuple_shape, "parameter"); @@ -78,38 +78,38 @@ xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop( return unpack_tuple(outputs, arity, builder); } -xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaForEachIndex( +xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex( int64 num_iterations, xla::PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - gtl::ArraySlice<xla::ComputationDataHandle> initial_values, - StringPiece name, xla::ComputationBuilder* builder) { - auto while_cond_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values, - xla::ComputationBuilder* cond_builder) - -> xla::StatusOr<xla::ComputationDataHandle> { + gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name, + xla::XlaBuilder* builder) { + auto while_cond_fn = + [&](gtl::ArraySlice<xla::XlaOp> values, + xla::XlaBuilder* cond_builder) -> xla::StatusOr<xla::XlaOp> { return cond_builder->Lt( values[0], IntegerLiteral(cond_builder, num_iterations_type, num_iterations)); }; - auto while_body_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values, - xla::ComputationBuilder* body_builder) - -> xla::StatusOr<std::vector<xla::ComputationDataHandle>> { - xla::ComputationDataHandle iteration = values[0]; + auto while_body_fn = [&](gtl::ArraySlice<xla::XlaOp> values, + xla::XlaBuilder* body_builder) + -> xla::StatusOr<std::vector<xla::XlaOp>> { + xla::XlaOp iteration = values[0]; - std::vector<xla::ComputationDataHandle> updated_values; + std::vector<xla::XlaOp> updated_values; updated_values.reserve(values.size()); updated_values.push_back(body_builder->Add( iteration, body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type)))); values.remove_prefix(1); - TF_ASSIGN_OR_RETURN(std::vector<xla::ComputationDataHandle> body_outputs, + TF_ASSIGN_OR_RETURN(std::vector<xla::XlaOp> body_outputs, body_function(iteration, values, body_builder)); updated_values.insert(updated_values.end(), body_outputs.begin(), body_outputs.end()); return updated_values; }; - std::vector<xla::ComputationDataHandle> values; + std::vector<xla::XlaOp> values; values.reserve(initial_values.size() + 1); values.push_back( builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type))); |