diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-04 10:44:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 10:55:14 -0700 |
commit | 419fff9de94ea9573f2e368fd6a68fdf54c59bab (patch) | |
tree | 3ce6bbcbb232da57ace60ebaaddc22971f0273f7 /tensorflow | |
parent | 38f803498f448b2eecdfeccbf2ce609e141e6cca (diff) |
Implement LiteralBase::Slice for all primitive type
PiperOrigin-RevId: 215764305
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/compiler/xla/literal.cc | 30 |
1 files changed, 24 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index deeb140b8f..177f39cc74 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -727,16 +727,34 @@ Literal LiteralBase::Slice(absl::Span<const int64> start_indices, ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, LayoutUtil::MinorToMajor(shape())); switch (result_shape.element_type()) { - case F32: - return SliceInternal<float>(result_shape, start_indices); + case PRED: + return SliceInternal<bool>(result_shape, start_indices); + case U8: + return SliceInternal<uint8>(result_shape, start_indices); + case U16: + return SliceInternal<uint16>(result_shape, start_indices); + case U32: + return SliceInternal<uint32>(result_shape, start_indices); + case U64: + return SliceInternal<uint64>(result_shape, start_indices); + case S8: + return SliceInternal<int8>(result_shape, start_indices); + case S16: + return SliceInternal<int16>(result_shape, start_indices); + case S32: + return SliceInternal<int32>(result_shape, start_indices); + case S64: + return SliceInternal<int64>(result_shape, start_indices); + case F16: + return SliceInternal<half>(result_shape, start_indices); case BF16: return SliceInternal<bfloat16>(result_shape, start_indices); + case F32: + return SliceInternal<float>(result_shape, start_indices); + case F64: + return SliceInternal<double>(result_shape, start_indices); case C64: return SliceInternal<complex64>(result_shape, start_indices); - case S32: - return SliceInternal<int32>(result_shape, start_indices); - case U32: - return SliceInternal<uint32>(result_shape, start_indices); default: LOG(FATAL) << "not yet implemented: " << PrimitiveType_Name(result_shape.element_type()); |