aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-04 10:44:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 10:55:14 -0700
commit419fff9de94ea9573f2e368fd6a68fdf54c59bab (patch)
tree3ce6bbcbb232da57ace60ebaaddc22971f0273f7 /tensorflow/compiler
parent38f803498f448b2eecdfeccbf2ce609e141e6cca (diff)
Implement LiteralBase::Slice for all primitive type
PiperOrigin-RevId: 215764305
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/literal.cc30
1 files changed, 24 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index deeb140b8f..177f39cc74 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -727,16 +727,34 @@ Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
LayoutUtil::MinorToMajor(shape()));
switch (result_shape.element_type()) {
- case F32:
- return SliceInternal<float>(result_shape, start_indices);
+ case PRED:
+ return SliceInternal<bool>(result_shape, start_indices);
+ case U8:
+ return SliceInternal<uint8>(result_shape, start_indices);
+ case U16:
+ return SliceInternal<uint16>(result_shape, start_indices);
+ case U32:
+ return SliceInternal<uint32>(result_shape, start_indices);
+ case U64:
+ return SliceInternal<uint64>(result_shape, start_indices);
+ case S8:
+ return SliceInternal<int8>(result_shape, start_indices);
+ case S16:
+ return SliceInternal<int16>(result_shape, start_indices);
+ case S32:
+ return SliceInternal<int32>(result_shape, start_indices);
+ case S64:
+ return SliceInternal<int64>(result_shape, start_indices);
+ case F16:
+ return SliceInternal<half>(result_shape, start_indices);
case BF16:
return SliceInternal<bfloat16>(result_shape, start_indices);
+ case F32:
+ return SliceInternal<float>(result_shape, start_indices);
+ case F64:
+ return SliceInternal<double>(result_shape, start_indices);
case C64:
return SliceInternal<complex64>(result_shape, start_indices);
- case S32:
- return SliceInternal<int32>(result_shape, start_indices);
- case U32:
- return SliceInternal<uint32>(result_shape, start_indices);
default:
LOG(FATAL) << "not yet implemented: "
<< PrimitiveType_Name(result_shape.element_type());