aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/lib/while_loop.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/lib/while_loop.cc')
-rw-r--r--tensorflow/compiler/tf2xla/lib/while_loop.cc52
1 files changed, 26 insertions, 26 deletions
diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc
index 495d9c6078..09ce594930 100644
--- a/tensorflow/compiler/tf2xla/lib/while_loop.cc
+++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc
@@ -20,24 +20,24 @@ limitations under the License.
namespace tensorflow {
-xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop(
+xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
const LoopConditionFunction& condition_function,
const LoopBodyFunction& body_function,
- gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
- StringPiece name, xla::ComputationBuilder* builder) {
+ gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+ xla::XlaBuilder* builder) {
int arity = initial_values.size();
std::vector<xla::Shape> var_shapes;
var_shapes.reserve(arity);
- for (const xla::ComputationDataHandle& input : initial_values) {
+ for (const xla::XlaOp& input : initial_values) {
TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input));
- var_shapes.push_back(std::move(*shape));
+ var_shapes.push_back(std::move(shape));
}
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes);
// Unpacks a tuple into its component parts.
- auto unpack_tuple = [](xla::ComputationDataHandle tuple, int arity,
- xla::ComputationBuilder* builder) {
- std::vector<xla::ComputationDataHandle> elements(arity);
+ auto unpack_tuple = [](xla::XlaOp tuple, int arity,
+ xla::XlaBuilder* builder) {
+ std::vector<xla::XlaOp> elements(arity);
for (int i = 0; i < arity; ++i) {
elements[i] = builder->GetTupleElement(tuple, i);
}
@@ -45,20 +45,20 @@ xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop(
};
// Build the condition.
- std::unique_ptr<xla::ComputationBuilder> cond_builder =
+ std::unique_ptr<xla::XlaBuilder> cond_builder =
builder->CreateSubBuilder(strings::StrCat(name, "_condition"));
{
auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter");
- TF_ASSIGN_OR_RETURN(
- auto result,
+ TF_RETURN_IF_ERROR(
condition_function(unpack_tuple(parameter, arity, cond_builder.get()),
- cond_builder.get()));
+ cond_builder.get())
+ .status());
}
TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build());
// Build the body.
- std::unique_ptr<xla::ComputationBuilder> body_builder =
+ std::unique_ptr<xla::XlaBuilder> body_builder =
builder->CreateSubBuilder(strings::StrCat(name, "_body"));
{
auto parameter = body_builder->Parameter(0, tuple_shape, "parameter");
@@ -78,38 +78,38 @@ xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop(
return unpack_tuple(outputs, arity, builder);
}
-xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaForEachIndex(
+xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
int64 num_iterations, xla::PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function,
- gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
- StringPiece name, xla::ComputationBuilder* builder) {
- auto while_cond_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values,
- xla::ComputationBuilder* cond_builder)
- -> xla::StatusOr<xla::ComputationDataHandle> {
+ gtl::ArraySlice<xla::XlaOp> initial_values, StringPiece name,
+ xla::XlaBuilder* builder) {
+ auto while_cond_fn =
+ [&](gtl::ArraySlice<xla::XlaOp> values,
+ xla::XlaBuilder* cond_builder) -> xla::StatusOr<xla::XlaOp> {
return cond_builder->Lt(
values[0],
IntegerLiteral(cond_builder, num_iterations_type, num_iterations));
};
- auto while_body_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values,
- xla::ComputationBuilder* body_builder)
- -> xla::StatusOr<std::vector<xla::ComputationDataHandle>> {
- xla::ComputationDataHandle iteration = values[0];
+ auto while_body_fn = [&](gtl::ArraySlice<xla::XlaOp> values,
+ xla::XlaBuilder* body_builder)
+ -> xla::StatusOr<std::vector<xla::XlaOp>> {
+ xla::XlaOp iteration = values[0];
- std::vector<xla::ComputationDataHandle> updated_values;
+ std::vector<xla::XlaOp> updated_values;
updated_values.reserve(values.size());
updated_values.push_back(body_builder->Add(
iteration,
body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type))));
values.remove_prefix(1);
- TF_ASSIGN_OR_RETURN(std::vector<xla::ComputationDataHandle> body_outputs,
+ TF_ASSIGN_OR_RETURN(std::vector<xla::XlaOp> body_outputs,
body_function(iteration, values, body_builder));
updated_values.insert(updated_values.end(), body_outputs.begin(),
body_outputs.end());
return updated_values;
};
- std::vector<xla::ComputationDataHandle> values;
+ std::vector<xla::XlaOp> values;
values.reserve(initial_values.size() + 1);
values.push_back(
builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type)));