aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-03-15 14:50:47 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-15 16:04:30 -0700
commitdcd71f6343c086ebd5dd4875e57bc92d9465e769 (patch)
tree36733867e186862fc5a09928793ed17f1dac5a8c
parent90c5838c5d8fc672b020e4baa3d5138f3940cd03 (diff)
[XLA] Give Transpose its own Request, rather than piggybacking on ReshapeRequest. Avoids building unnecessary Reshape operators when Transpose was called by the client.
Also avoids building Transpose operators when Reshape has identity transpose dimensions, for example when the client called the variant of ComputationBuilder::Reshape() that does not transpose. Makes the HLO graph emitted by the TF bridge more readable. Change: 150253949
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.cc22
-rw-r--r--tensorflow/compiler/xla/service/service.cc4
-rw-r--r--tensorflow/compiler/xla/service/user_computation.cc68
-rw-r--r--tensorflow/compiler/xla/service/user_computation.h4
-rw-r--r--tensorflow/compiler/xla/util.cc9
-rw-r--r--tensorflow/compiler/xla/util.h3
-rw-r--r--tensorflow/compiler/xla/xla_data.proto10
7 files changed, 102 insertions, 18 deletions
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index 70afaf2ccb..4b7f9c1822 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -981,19 +981,23 @@ ComputationDataHandle ComputationBuilder::IsFinite(
ComputationDataHandle ComputationBuilder::Transpose(
const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<int64> permutation) {
- if (!first_error_.ok()) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
return ComputationDataHandle();
}
- StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
- if (!shape.ok()) {
- // Just early return with the existing error status.
- first_error_ = shape.status();
- return ComputationDataHandle();
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ TransposeRequest* request = op_request.mutable_transpose_request();
+ *request->mutable_operand() = operand;
+ for (int64 dimension : permutation) {
+ request->add_dimensions(dimension);
}
- return Reshape(operand, permutation,
- Permute(InversePermutation(permutation),
- AsInt64Slice(shape.ValueOrDie()->dimensions())));
+ AddOpMetadata(&op_request);
+ OpResponse response;
+
+ VLOG(2) << "making transpose request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
}
ComputationDataHandle ComputationBuilder::Rev(
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 905675e301..d88315e747 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -1370,6 +1370,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
break;
case OpRequest::kTraceRequest:
return computation->AddTraceInstruction(arg->trace_request());
+ case OpRequest::kTransposeRequest:
+ handle_status =
+ computation->AddTransposeInstruction(arg->transpose_request());
+ break;
case OpRequest::kUnaryOpRequest:
handle_status = computation->AddUnaryInstruction(arg->unary_op_request());
break;
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 5b7e253977..a77788e0b6 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -662,6 +662,33 @@ StatusOr<ComputationDataHandle> UserComputation::AddReshapeInstruction(
return handle;
}
+StatusOr<ComputationDataHandle> UserComputation::AddTransposeInstruction(
+ const TransposeRequest& transpose_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ // Fetches and validates the operand.
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+ LookUpRequest(transpose_request.operand()));
+
+ TF_ASSIGN_OR_RETURN(Shape inferred_shape,
+ ShapeInference::InferTransposeShape(
+ operand->output_shape(),
+ AsInt64Slice(transpose_request.dimensions())));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = inferred_shape;
+ *request.mutable_request()->mutable_transpose_request() = transpose_request;
+
+ VLOG(1) << "AddTransposeInstruction (" << GetVersionedHandleInternal()
+ << "), data handle " << handle.handle() << ": "
+ << transpose_request.ShortDebugString();
+ return handle;
+}
+
StatusOr<ComputationDataHandle> UserComputation::AddSliceInstruction(
const SliceRequest& slice_request) {
tensorflow::mutex_lock lock(mutex_);
@@ -1498,6 +1525,14 @@ void ConstantVisitor(const SessionComputation& session_computation,
break;
}
+ case OpRequest::kTransposeRequest: {
+ const TransposeRequest& transpose_request =
+ request.request().transpose_request();
+ ConstantVisitor(session_computation, transpose_request.operand(), visited,
+ is_constant);
+ break;
+ }
+
case OpRequest::kVariadicOpRequest: {
const VariadicOpRequest& variadic_op_request =
request.request().variadic_op_request();
@@ -2125,15 +2160,32 @@ HloInstruction* ComputationLowerer::Visit(
const ReshapeRequest& reshape_request =
request.request().reshape_request();
HloInstruction* operand = Visit(reshape_request.operand(), visited);
+ HloInstruction* transposed;
+ if (IsIdentityPermutation(AsInt64Slice(reshape_request.dimensions()))) {
+ transposed = operand;
+ } else {
+ transposed =
+ hlo_builder_.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::PermuteDimensions(InversePermutation(AsInt64Slice(
+ reshape_request.dimensions())),
+ operand->shape()),
+ operand, AsInt64Slice(reshape_request.dimensions())));
+ }
+ hlo_instruction = hlo_builder_.AddInstruction(
+ HloInstruction::CreateReshape(request.output_shape(), transposed));
+ break;
+ }
+
+ case OpRequest::kTransposeRequest: {
+ const TransposeRequest& transpose_request =
+ request.request().transpose_request();
+ HloInstruction* operand = Visit(transpose_request.operand(), visited);
hlo_instruction =
- hlo_builder_.AddInstruction(HloInstruction::CreateReshape(
- request.output_shape(),
- hlo_builder_.AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::PermuteDimensions(
- InversePermutation(
- AsInt64Slice(reshape_request.dimensions())),
- operand->shape()),
- operand, AsInt64Slice(reshape_request.dimensions())))));
+ hlo_builder_.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::PermuteDimensions(InversePermutation(AsInt64Slice(
+ transpose_request.dimensions())),
+ operand->shape()),
+ operand, AsInt64Slice(transpose_request.dimensions())));
break;
}
diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h
index 55475e727b..fb5425ae61 100644
--- a/tensorflow/compiler/xla/service/user_computation.h
+++ b/tensorflow/compiler/xla/service/user_computation.h
@@ -144,6 +144,10 @@ class UserComputation {
StatusOr<ComputationDataHandle> AddReshapeInstruction(
const ReshapeRequest& reshape_request);
+ // Enqueues a transpose instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddTransposeInstruction(
+ const TransposeRequest& transpose_request);
+
// Enqueues a slice instruction onto this user computation.
StatusOr<ComputationDataHandle> AddSliceInstruction(
const SliceRequest& slice_request);
diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc
index 3ee5dfc949..a711b5035d 100644
--- a/tensorflow/compiler/xla/util.cc
+++ b/tensorflow/compiler/xla/util.cc
@@ -176,6 +176,15 @@ std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
return output;
}
+bool IsIdentityPermutation(tensorflow::gtl::ArraySlice<int64> p) {
+ for (int64 i = 0; i < p.size(); ++i) {
+ if (p[i] != i) {
+ return false;
+ }
+ }
+ return true;
+}
+
PaddingConfig MakeNoPaddingConfig(int64 rank) {
PaddingConfig padding_config;
for (int64 dnum = 0; dnum < rank; ++dnum) {
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index 00f8d946f8..55a66a7499 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -183,6 +183,9 @@ std::vector<int64> InversePermutation(
std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
tensorflow::gtl::ArraySlice<int64> p2);
+// Returns true iff permutation == {0, 1, 2, ...}.
+bool IsIdentityPermutation(tensorflow::gtl::ArraySlice<int64> permutation);
+
template <typename Container>
int64 PositionInContainer(const Container& container, int64 value) {
return std::distance(container.begin(),
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 8f63e18140..2bb09c069c 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -511,6 +511,13 @@ message ReshapeRequest {
repeated int64 new_sizes = 4;
}
+message TransposeRequest {
+ ComputationDataHandle operand = 2;
+
+ // The permutation of the operand's dimensions (in the range 0 to n-1).
+ repeated int64 dimensions = 3;
+}
+
message ParameterRequest {
Shape shape = 2;
int64 parameter = 3;
@@ -743,13 +750,14 @@ message OpRequest {
SliceRequest slice_request = 24;
TernaryOpRequest ternary_op_request = 25;
TraceRequest trace_request = 26;
+ TransposeRequest transpose_request = 34;
UnaryOpRequest unary_op_request = 27;
VariadicOpRequest variadic_op_request = 28;
WhileRequest while_request = 29;
SendRequest send_request = 30;
RecvRequest recv_request = 31;
OutfeedRequest outfeed_request = 32;
- // Next: 34
+ // Next: 35
}
}