aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_proto_util.h
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-03-05 13:44:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-05 13:48:53 -0800
commit36b3c94a99704c8e1973ae5c043aec4870ae84ff (patch)
tree1b40289dbc0d6de7ec8ff3e5f15823754c867bd3 /tensorflow/compiler/xla/service/hlo_proto_util.h
parent5368a1a3af94c6b49dd51d0d85cb3702f484daa7 (diff)
Add methods for extracting the shapes of the entry computation from an HloProto.
PiperOrigin-RevId: 187915821
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_proto_util.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_proto_util.h9
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.h b/tensorflow/compiler/xla/service/hlo_proto_util.h
index 320288fdb9..3d9c375cd5 100644
--- a/tensorflow/compiler/xla/service/hlo_proto_util.h
+++ b/tensorflow/compiler/xla/service/hlo_proto_util.h
@@ -35,6 +35,15 @@ HloProto MakeHloProto(const HloModule& module,
// will not be included in the output.
HloProto MakeHloProto(const HloModule& module);
+// Returns the shapes of the parameters of the entry computation. Shape pointers
+// refer to shapes inside of the given HloProto.
+StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes(
+ const HloProto& hlo_proto);
+
+// Returns the shape of the output of the entry computation. The shape pointer
+// refers to the output shape inside of the given HloProto.
+StatusOr<const Shape*> EntryComputationOutputShape(const HloProto& hlo_proto);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROTO_UTIL_H_