diff options
Diffstat (limited to 'tensorflow/compiler/xla/client/xla_builder.h')
-rw-r--r-- | tensorflow/compiler/xla/client/xla_builder.h | 45 |
1 files changed, 25 insertions, 20 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 58e8f4e7fa..d0c59fa6f2 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stacktrace.h" @@ -955,6 +956,8 @@ class XlaBuilder { HloInstructionProto* instr); StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const; + StatusOr<const HloInstructionProto*> LookUpInstructionByHandle( + int64 handle) const; // Internal helper method that does the building for an arbitrary unary op. XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand); @@ -1024,6 +1027,10 @@ class XlaBuilder { // The instructions of this computation. std::vector<HloInstructionProto> instructions_; + // A map from XlaOp::Handle to the index in the instructions_ vector where the + // instruction is held. + tensorflow::gtl::FlatMap<int64, int64> handle_to_index_; + // The embedded computations used by this computation. Each computation was // the entry computation of some XlaComputation, the key is the unique id of // that XlaComputation. @@ -2112,12 +2119,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, template <typename NativeT> XlaOp XlaBuilder::ConstantR0(NativeT value) { - return ConstantLiteral(*LiteralUtil::CreateR0<NativeT>(value)); + return ConstantLiteral(LiteralUtil::CreateR0<NativeT>(value)); } template <typename NativeT> XlaOp XlaBuilder::ConstantR1(absl::Span<const NativeT> values) { - return ConstantLiteral(*LiteralUtil::CreateR1<NativeT>(values)); + return ConstantLiteral(LiteralUtil::CreateR1<NativeT>(values)); } template <typename NativeT> @@ -2129,44 +2136,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) { } inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) { - return ConstantLiteral(*LiteralUtil::CreateR1(values)); + return ConstantLiteral(LiteralUtil::CreateR1(values)); } template <typename NativeT> XlaOp XlaBuilder::ConstantR2( std::initializer_list<std::initializer_list<NativeT>> values) { - return ConstantLiteral(*LiteralUtil::CreateR2<NativeT>(values)); + return ConstantLiteral(LiteralUtil::CreateR2<NativeT>(values)); } template <typename NativeT> XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values, const Layout& layout) { return ConstantLiteral( - *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout)); + LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout)); } template <typename NativeT> XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) { - return ConstantLiteral(*LiteralUtil::CreateFromArray<NativeT>(values)); + return ConstantLiteral(LiteralUtil::CreateFromArray<NativeT>(values)); } template <typename NativeT> XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout( const Array2D<NativeT>& values, const Layout& layout) { return ConstantLiteral( - *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout)); + LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout)); } template <typename NativeT> XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) { - return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D<NativeT>(values)); + return ConstantLiteral(LiteralUtil::CreateR2FromArray2D<NativeT>(values)); } template <typename NativeT> XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout( const Array3D<NativeT>& values, const Layout& layout) { return ConstantLiteral( - *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout)); + LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout)); } template <typename NativeT> @@ -2189,12 +2196,12 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) { template <typename NativeT> XlaOp ConstantR0(XlaBuilder* builder, NativeT value) { - return ConstantLiteral(builder, *LiteralUtil::CreateR0<NativeT>(value)); + return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value)); } template <typename NativeT> XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) { - return ConstantLiteral(builder, *LiteralUtil::CreateR1<NativeT>(values)); + return ConstantLiteral(builder, LiteralUtil::CreateR1<NativeT>(values)); } template <typename NativeT> @@ -2207,13 +2214,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) { inline XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values) { - return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); + return ConstantLiteral(builder, LiteralUtil::CreateR1(values)); } template <typename NativeT> XlaOp ConstantR2(XlaBuilder* builder, std::initializer_list<std::initializer_list<NativeT>> values) { - return ConstantLiteral(builder, *LiteralUtil::CreateR2<NativeT>(values)); + return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values)); } template <typename NativeT> @@ -2221,14 +2228,13 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, const Array<NativeT>& values, const Layout& layout) { return ConstantLiteral( - builder, - *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout)); + builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout)); } template <typename NativeT> XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) { return ConstantLiteral(builder, - *LiteralUtil::CreateFromArray<NativeT>(values)); + LiteralUtil::CreateFromArray<NativeT>(values)); } template <typename NativeT> @@ -2236,15 +2242,14 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, const Array2D<NativeT>& values, const Layout& layout) { return ConstantLiteral( - builder, - *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout)); + builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout)); } template <typename NativeT> XlaOp ConstantR2FromArray2D(XlaBuilder* builder, const Array2D<NativeT>& values) { return ConstantLiteral(builder, - *LiteralUtil::CreateR2FromArray2D<NativeT>(values)); + LiteralUtil::CreateR2FromArray2D<NativeT>(values)); } template <typename NativeT> @@ -2253,7 +2258,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, const Layout& layout) { return ConstantLiteral( builder, - *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout)); + LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout)); } template <typename NativeT> |