aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/client/xla_builder.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/client/xla_builder.h')
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h45
1 files changed, 25 insertions, 20 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 58e8f4e7fa..d0c59fa6f2 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
@@ -955,6 +956,8 @@ class XlaBuilder {
HloInstructionProto* instr);
StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
+ StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
+ int64 handle) const;
// Internal helper method that does the building for an arbitrary unary op.
XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
@@ -1024,6 +1027,10 @@ class XlaBuilder {
// The instructions of this computation.
std::vector<HloInstructionProto> instructions_;
+ // A map from XlaOp::Handle to the index in the instructions_ vector where the
+ // instruction is held.
+ tensorflow::gtl::FlatMap<int64, int64> handle_to_index_;
+
// The embedded computations used by this computation. Each computation was
// the entry computation of some XlaComputation, the key is the unique id of
// that XlaComputation.
@@ -2112,12 +2119,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
template <typename NativeT>
XlaOp XlaBuilder::ConstantR0(NativeT value) {
- return ConstantLiteral(*LiteralUtil::CreateR0<NativeT>(value));
+ return ConstantLiteral(LiteralUtil::CreateR0<NativeT>(value));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR1(absl::Span<const NativeT> values) {
- return ConstantLiteral(*LiteralUtil::CreateR1<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateR1<NativeT>(values));
}
template <typename NativeT>
@@ -2129,44 +2136,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) {
}
inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(*LiteralUtil::CreateR1(values));
+ return ConstantLiteral(LiteralUtil::CreateR1(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(*LiteralUtil::CreateR2<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateR2<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) {
- return ConstantLiteral(*LiteralUtil::CreateFromArray<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateFromArray<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout) {
return ConstantLiteral(
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) {
- return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D<NativeT>(values));
+ return ConstantLiteral(LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout) {
return ConstantLiteral(
- *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
@@ -2189,12 +2196,12 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR0<NativeT>(value));
+ return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
}
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR1<NativeT>(values));
+ return ConstantLiteral(builder, LiteralUtil::CreateR1<NativeT>(values));
}
template <typename NativeT>
@@ -2207,13 +2214,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
inline XlaOp ConstantR1(XlaBuilder* builder,
const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR1(values));
+ return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
}
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(builder, *LiteralUtil::CreateR2<NativeT>(values));
+ return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
}
template <typename NativeT>
@@ -2221,14 +2228,13 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
const Array<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- builder,
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
return ConstantLiteral(builder,
- *LiteralUtil::CreateFromArray<NativeT>(values));
+ LiteralUtil::CreateFromArray<NativeT>(values));
}
template <typename NativeT>
@@ -2236,15 +2242,14 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
const Array2D<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- builder,
- *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+ builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
const Array2D<NativeT>& values) {
return ConstantLiteral(builder,
- *LiteralUtil::CreateR2FromArray2D<NativeT>(values));
+ LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}
template <typename NativeT>
@@ -2253,7 +2258,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
const Layout& layout) {
return ConstantLiteral(
builder,
- *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+ LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
template <typename NativeT>