aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/literal_test_util.h
diff options
context:
space:
mode:
authorGravatar Yuanzhong Xu <yuanzx@google.com>2017-12-06 20:08:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-06 20:12:02 -0800
commitf75481874fb7314c907b1770ea04c851b9ec07d4 (patch)
treeec31c8eb73c2bfb4dbaf411e36fbbb9fa0ba16e5 /tensorflow/compiler/xla/tests/literal_test_util.h
parent846a73f9f336e54a02c12388ac76a0aa8700543a (diff)
Tuple literal conversions for BF16 and F32
PiperOrigin-RevId: 178191335
Diffstat (limited to 'tensorflow/compiler/xla/tests/literal_test_util.h')
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.h8
1 files changed, 6 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
index 6e4add2690..bf8c92f16d 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.h
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -59,10 +59,14 @@ class LiteralTestUtil {
static void AssertEqualShapesAndLayouts(const Shape& expected,
const Shape& actual);
- // Converts a bfloat16 literal to a float literal.
+ // If the given literal's data type is bfloat16, converts it to a float
+ // literal; otherwise, returns a copy of it. If the literal is a tuple,
+ // recursively converts its elements.
static std::unique_ptr<Literal> ConvertBF16ToF32(const Literal& bf16_literal);
- // Converts a float literal to a bfloat16 literal.
+ // If the given literal's data type is float, converts it to a bfloat16
+ // literal; otherwise, returns a copy of it. If the literal is a tuple,
+ // recursively converts its elements.
static std::unique_ptr<Literal> ConvertF32ToBF16(const Literal& f32_literal);
// Asserts that the expected and actual literals are (bitwise) equal for all