aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/bfloat16_support.h
diff options
context:
space:
mode:
authorGravatar Yuanzhong Xu <yuanzx@google.com>2018-02-12 11:26:22 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-12 11:30:18 -0800
commitfabf6ddede109bbf18115718224449c314bcf92a (patch)
treeb36dd4f1ee9b42e81fb12a448f2dfd867d7180ec /tensorflow/compiler/xla/service/bfloat16_support.h
parent075931641e9147f0faf16e0ce2b76525620e1be0 (diff)
[XLA] An HLO pass that folds BF16 F32 conversions: if an HLO already supports BF16 input/output, conversions before/after it will be removed and the HLO's input/output types will be converted to BF16.
Also updates HloVerifier to allow mixed precision if requested. If an HLO has both both F32 and BF16 inputs, ShapeInference will use F32 as the output type. PiperOrigin-RevId: 185407143
Diffstat (limited to 'tensorflow/compiler/xla/service/bfloat16_support.h')
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_support.h60
1 files changed, 60 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/bfloat16_support.h b/tensorflow/compiler/xla/service/bfloat16_support.h
new file mode 100644
index 0000000000..29f662d22b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/bfloat16_support.h
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+
+namespace xla {
+
+class BFloat16Support {
+ public:
+ BFloat16Support() {}
+ virtual ~BFloat16Support() {}
+
+ // Returns whether the backend supports BF16 operand for the HLO instruction
+ // at the given index.
+ virtual bool SupportsBF16Operand(const HloInstruction& hlo,
+ int64 operand_index) const;
+
+ // Returns whether the backend supports BF16 output for the HLO instruction.
+ virtual bool SupportsBF16Output(const HloInstruction& hlo) const;
+
+ // Returns whether the backend support mixed precision: the operands, output,
+ // and parameters/output of the called computations can have different
+ // precisions (BF16 and F32).
+ virtual bool SupportsMixedPrecisions(const HloInstruction& hlo) const;
+
+ // Returns whether the given HLO inherits its BF16 operand precision at the
+ // given index, so even if the output is F32, elements in the output that
+ // depend on the BF16 operand will still have BF16 effective precision even if
+ // they have F32 format. Similarly, this also means if the output is BF16 then
+ // increasing the operand precision from BF16 to F32 will not change the
+ // output. This typically includes HLOs that pass elements from the operand to
+ // the output without arithmetic operations.
+ static bool EffectiveOperandPrecisionIsOutputPrecision(
+ const HloInstruction& hlo, int64 operand_index);
+
+ // Returns if the backend only uses BF16 precision for the operand at the
+ // specified index, even if the operand is F32.
+ virtual bool EffectiveOperandPrecisionIsBF16(const HloInstruction& hlo,
+ int64 operand_index) const;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_