diff options
author | 2018-01-02 15:51:47 -0800 | |
---|---|---|
committer | 2018-01-02 15:54:53 -0800 | |
commit | 02a66fe2afa8c5e1273e235bd0dd405b8108d34f (patch) | |
tree | afd33b531dcfab5d633fff4e240d65cf219e61fb | |
parent | 425a71083fddcbdeaf415443e9c5bec2eb356a4b (diff) |
Add default constructor to ServiceExecutableRunOptions.
PiperOrigin-RevId: 180604410
-rw-r--r-- | tensorflow/compiler/xla/literal_util.cc | 26 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/service_executable_run_options.h | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/shape_util.cc | 2 |
3 files changed, 17 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index f493460e79..3e909f76f9 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -1210,84 +1210,84 @@ Literal::GetMutableArraySlice<bfloat16>() { template <> tensorflow::gtl::ArraySlice<bool> Literal::GetArraySlice<bool>() const { - CHECK_EQ(shape().element_type(), PRED); + CHECK_EQ(shape().element_type(), PRED) << ShapeUtil::HumanString(shape()); return tensorflow::gtl::ArraySlice<bool>( reinterpret_cast<const bool*>(preds().data()), preds().size()); } template <> tensorflow::gtl::ArraySlice<uint8> Literal::GetArraySlice<uint8>() const { - CHECK_EQ(shape().element_type(), U8); + CHECK_EQ(shape().element_type(), U8) << ShapeUtil::HumanString(shape()); return tensorflow::gtl::ArraySlice<uint8>( reinterpret_cast<const uint8*>(u8s().data()), u8s().size()); } template <> tensorflow::gtl::ArraySlice<int8> Literal::GetArraySlice<int8>() const { - CHECK_EQ(shape().element_type(), S8); + CHECK_EQ(shape().element_type(), S8) << ShapeUtil::HumanString(shape()); return tensorflow::gtl::ArraySlice<int8>( reinterpret_cast<const int8*>(u8s().data()), u8s().size()); } template <> tensorflow::gtl::ArraySlice<uint16> Literal::GetArraySlice<uint16>() const { - CHECK_EQ(shape().element_type(), U16); + CHECK_EQ(shape().element_type(), U16) << ShapeUtil::HumanString(shape()); return tensorflow::gtl::ArraySlice<uint16>(u16s().data(), u16s().size()); } template <> tensorflow::gtl::ArraySlice<int16> Literal::GetArraySlice<int16>() const { - CHECK_EQ(shape().element_type(), S16); + CHECK_EQ(shape().element_type(), S16) << ShapeUtil::HumanString(shape()); return tensorflow::gtl::ArraySlice<int16>(s16s().data(), s16s().size()); } template <> tensorflow::gtl::ArraySlice<uint32> Literal::GetArraySlice<uint32>() const { - CHECK_EQ(shape().element_type(), U32); + CHECK_EQ(shape().element_type(), U32) << ShapeUtil::HumanString(shape()); return u32s(); } template <> tensorflow::gtl::ArraySlice<uint64> Literal::GetArraySlice<uint64>() const { - CHECK_EQ(shape().element_type(), U64); + CHECK_EQ(shape().element_type(), U64) << ShapeUtil::HumanString(shape()); return u64s(); } template <> tensorflow::gtl::ArraySlice<int32> Literal::GetArraySlice<int32>() const { - CHECK_EQ(shape().element_type(), S32); + CHECK_EQ(shape().element_type(), S32) << ShapeUtil::HumanString(shape()); return s32s(); } template <> tensorflow::gtl::ArraySlice<int64> Literal::GetArraySlice<int64>() const { - CHECK_EQ(shape().element_type(), S64); + CHECK_EQ(shape().element_type(), S64) << ShapeUtil::HumanString(shape()); return s64s(); } template <> tensorflow::gtl::ArraySlice<double> Literal::GetArraySlice<double>() const { - CHECK_EQ(shape().element_type(), F64); + CHECK_EQ(shape().element_type(), F64) << ShapeUtil::HumanString(shape()); return f64s(); } template <> tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const { - CHECK_EQ(shape().element_type(), F16); + CHECK_EQ(shape().element_type(), F16) << ShapeUtil::HumanString(shape()); return tensorflow::gtl::ArraySlice<half>(f16s().data(), f16s().size() / sizeof(half)); } template <> tensorflow::gtl::ArraySlice<bfloat16> Literal::GetArraySlice<bfloat16>() const { - CHECK_EQ(shape().element_type(), BF16); + CHECK_EQ(shape().element_type(), BF16) << ShapeUtil::HumanString(shape()); return {bf16s().data(), bf16s().size()}; } template <> tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>() const { - CHECK_EQ(shape().element_type(), C64); + CHECK_EQ(shape().element_type(), C64) << ShapeUtil::HumanString(shape()); return c64s(); } diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h index 017e5ef09e..6c1f8feac7 100644 --- a/tensorflow/compiler/xla/service/service_executable_run_options.h +++ b/tensorflow/compiler/xla/service/service_executable_run_options.h @@ -30,6 +30,9 @@ class ServiceExecutableRunOptions { using StreamBorrower = std::function<StatusOr<Pool<perftools::gputools::Stream>::SmartPtr>(int)>; + ServiceExecutableRunOptions() + : ServiceExecutableRunOptions(ExecutableRunOptions()) {} + explicit ServiceExecutableRunOptions( ExecutableRunOptions run_options, StreamBorrower borrow_stream = nullptr, tensorflow::thread::ThreadPool* xla_intra_op_thread_pool = nullptr) diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 48b7515ecf..2c1b1d22ad 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -353,7 +353,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { - CHECK(!IsTuple(shape)); + CHECK(!IsTuple(shape)) << ShapeUtil::HumanString(shape); CHECK_EQ(shape.dimensions_size(), Rank(shape)); return std::accumulate<decltype(shape.dimensions().begin()), int64>( shape.dimensions().begin(), shape.dimensions().end(), 1LL, |