diff options
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/xla/literal.cc | 30 |
1 files changed, 24 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index deeb140b8f..177f39cc74 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -727,16 +727,34 @@ Literal LiteralBase::Slice(absl::Span<const int64> start_indices, ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, LayoutUtil::MinorToMajor(shape())); switch (result_shape.element_type()) { - case F32: - return SliceInternal<float>(result_shape, start_indices); + case PRED: + return SliceInternal<bool>(result_shape, start_indices); + case U8: + return SliceInternal<uint8>(result_shape, start_indices); + case U16: + return SliceInternal<uint16>(result_shape, start_indices); + case U32: + return SliceInternal<uint32>(result_shape, start_indices); + case U64: + return SliceInternal<uint64>(result_shape, start_indices); + case S8: + return SliceInternal<int8>(result_shape, start_indices); + case S16: + return SliceInternal<int16>(result_shape, start_indices); + case S32: + return SliceInternal<int32>(result_shape, start_indices); + case S64: + return SliceInternal<int64>(result_shape, start_indices); + case F16: + return SliceInternal<half>(result_shape, start_indices); case BF16: return SliceInternal<bfloat16>(result_shape, start_indices); + case F32: + return SliceInternal<float>(result_shape, start_indices); + case F64: + return SliceInternal<double>(result_shape, start_indices); case C64: return SliceInternal<complex64>(result_shape, start_indices); - case S32: - return SliceInternal<int32>(result_shape, start_indices); - case U32: - return SliceInternal<uint32>(result_shape, start_indices); default: LOG(FATAL) << "not yet implemented: " << PrimitiveType_Name(result_shape.element_type()); |