aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-14 11:18:05 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-14 11:21:42 -0800
commit9bbd25da077881696875447c3081f96c20e8728c (patch)
treebc55d97845e0acd017d90cc077bdfd8e89d4d049
parent9277bb73a926684d4346a56fec6c117873a9a84a (diff)
Enable bfloat16 tests and add a filter for currently
failing tests. PiperOrigin-RevId: 179069257
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py2
-rw-r--r--tensorflow/compiler/tests/tensor_array_ops_test.py3
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py3
-rw-r--r--tensorflow/compiler/tests/xla_test.py105
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.cc2
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.h2
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_registry.h8
-rw-r--r--tensorflow/core/kernels/split_op.cc2
-rw-r--r--tensorflow/python/framework/test_util.py6
9 files changed, 100 insertions, 33 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 654dc15e86..905dd9fc7b 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -547,7 +547,7 @@ class BinaryOpsTest(XLATestCase):
self._testDivision(dtype)
def testFloatDivision(self):
- for dtype in self.float_types + self.complex_types:
+ for dtype in self.float_types | self.complex_types:
self._testDivision(dtype)
def _testRemainder(self, dtype):
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py
index ac039e0162..a62925a181 100644
--- a/tensorflow/compiler/tests/tensor_array_ops_test.py
+++ b/tensorflow/compiler/tests/tensor_array_ops_test.py
@@ -330,8 +330,7 @@ class TensorArrayTest(xla_test.XLATestCase):
# Find two different floating point types, create an array of
# the first type, but try to read the other type.
if len(self.float_types) > 1:
- dtype1 = self.float_types[0]
- dtype2 = self.float_types[1]
+ dtype1, dtype2 = list(self.float_types)[:2]
with self.test_session(), self.test_scope():
ta = tensor_array_ops.TensorArray(
dtype=dtype1, tensor_array_name="foo", size=3)
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 0da7442a24..b0623c0fbc 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -573,7 +573,8 @@ class UnaryOpsTest(XLATestCase):
def testCast(self):
shapes = [[], [4], [2, 3], [2, 0, 4]]
- types = [dtypes.bool, dtypes.int32, dtypes.float32] + self.complex_tf_types
+ types = (set([dtypes.bool, dtypes.int32, dtypes.float32]) |
+ self.complex_tf_types)
for shape in shapes:
for src_type in types:
for dst_type in types:
diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py
index 0be127997e..7e1f5c76ed 100644
--- a/tensorflow/compiler/tests/xla_test.py
+++ b/tensorflow/compiler/tests/xla_test.py
@@ -53,41 +53,100 @@ class XLATestCase(test.TestCase):
super(XLATestCase, self).__init__(method_name)
self.device = FLAGS.test_device
self.has_custom_call = (self.device == 'XLA_CPU')
- self.all_tf_types = [
+ self._all_tf_types = set([
dtypes.as_dtype(types_pb2.DataType.Value(name))
for name in FLAGS.types.split(',')
- ]
- self.int_tf_types = [
- dtype for dtype in self.all_tf_types if dtype.is_integer
- ]
- self.float_tf_types = [
- dtype for dtype in self.all_tf_types if dtype.is_floating
- ]
- self.complex_tf_types = [
- dtype for dtype in self.all_tf_types if dtype.is_complex
- ]
- self.numeric_tf_types = (
- self.int_tf_types + self.float_tf_types + self.complex_tf_types)
-
- self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types]
- self.int_types = [dtype.as_numpy_dtype for dtype in self.int_tf_types]
- self.float_types = [dtype.as_numpy_dtype for dtype in self.float_tf_types]
- self.complex_types = [
+ ])
+ self.int_tf_types = set([
+ dtype for dtype in self._all_tf_types if dtype.is_integer
+ ])
+ self._float_tf_types = set([
+ dtype for dtype in self._all_tf_types if dtype.is_floating
+ ])
+ self.complex_tf_types = set([
+ dtype for dtype in self._all_tf_types if dtype.is_complex
+ ])
+ self._numeric_tf_types = set(
+ self.int_tf_types | self._float_tf_types | self.complex_tf_types)
+
+ self._all_types = set(
+ [dtype.as_numpy_dtype for dtype in self._all_tf_types])
+ self.int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types])
+ self._float_types = set(
+ [dtype.as_numpy_dtype for dtype in self._float_tf_types])
+ self.complex_types = set([
dtype.as_numpy_dtype for dtype in self.complex_tf_types
- ]
- self.numeric_types = self.int_types + self.float_types + self.complex_types
+ ])
+ self._numeric_types = set(
+ self.int_types | self._float_types | self.complex_types)
# Parse the manifest file, if any, into a regex identifying tests to
# disable
self.disabled_regex = None
+ self._method_types_filter = dict()
+ # TODO(xpan): Make it text proto if it doesn't scale.
+    # Each line of the manifest file specifies an entry. An entry can be
+    # 1) TestNameRegex  // E.g. CumprodTest.*  Or
+    # 2) TestName TypeName  // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16
+    # Form 1) disables the entire test, while form 2) only filters out some
+    # numeric types so that they are not used in those tests.
+
if FLAGS.disabled_manifest is not None:
comments_re = re.compile('#.*$')
manifest_file = open(FLAGS.disabled_manifest, 'r')
- lines = manifest_file.read().splitlines()
- lines = [comments_re.sub('', l).strip() for l in lines]
- self.disabled_regex = re.compile('|'.join(lines))
+ disabled_tests = []
+ disabled_method_types = []
+ for l in manifest_file.read().splitlines():
+ entry = comments_re.sub('', l).strip().split(' ')
+ if len(entry) == 1:
+ disabled_tests.append(entry[0])
+ elif len(entry) == 2:
+ disabled_method_types.append(
+ (entry[0], entry[1].strip().split(',')))
+ else:
+ raise ValueError('Bad entry in manifest file.')
+
+ self.disabled_regex = re.compile('|'.join(disabled_tests))
+ for method, types in disabled_method_types:
+ self._method_types_filter[method] = set([
+ dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype
+ for name in types])
manifest_file.close()
+ @property
+ def all_tf_types(self):
+ name = '{}.{}'.format(type(self).__name__, self._testMethodName)
+ tf_types = set([dtypes.as_dtype(t)
+ for t in self._method_types_filter.get(name, set())])
+ return self._all_tf_types - tf_types
+
+ @property
+ def float_types(self):
+ name = '{}.{}'.format(type(self).__name__, self._testMethodName)
+ return self._float_types - self._method_types_filter.get(name, set())
+
+ @property
+ def float_tf_types(self):
+ name = '{}.{}'.format(type(self).__name__, self._testMethodName)
+ return self._float_tf_types - self._method_types_filter.get(name, set())
+
+ @property
+ def numeric_tf_types(self):
+ name = '{}.{}'.format(type(self).__name__, self._testMethodName)
+ tf_types = set([dtypes.as_dtype(t)
+ for t in self._method_types_filter.get(name, set())])
+ return self._numeric_tf_types - tf_types
+
+ @property
+ def numeric_types(self):
+ name = '{}.{}'.format(type(self).__name__, self._testMethodName)
+ return self._numeric_types - self._method_types_filter.get(name, set())
+
+ @property
+ def all_types(self):
+ name = '{}.{}'.format(type(self).__name__, self._testMethodName)
+ return self._all_types - self._method_types_filter.get(name, set())
+
def setUp(self):
super(XLATestCase, self).setUp()
name = '{}.{}'.format(type(self).__name__, self._testMethodName)
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index 943248aedb..ce24b61b5d 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -28,7 +28,7 @@ limitations under the License.
namespace tensorflow {
xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder,
- xla::Shape& shape) {
+ const xla::Shape& shape) {
return builder->Broadcast(
builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())),
xla::AsInt64Slice(shape.dimensions()));
diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h
index 8fba6b5cf2..fb138b4f73 100644
--- a/tensorflow/compiler/tf2xla/lib/util.h
+++ b/tensorflow/compiler/tf2xla/lib/util.h
@@ -25,7 +25,7 @@ namespace tensorflow {
// Returns a zero-filled tensor with shape `shape`.
xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder,
- xla::Shape& shape);
+ const xla::Shape& shape);
// Returns a floating point scalar constant of 'type' with 'value'.
// If 'type' is complex, returns a real value with zero imaginary component.
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index 2959d2ab69..8bfd9758f7 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -45,11 +45,11 @@ extern const char* const DEVICE_GPU_XLA_JIT; // "GPU_XLA_JIT"
extern const char* const DEVICE_XLA_CPU;
extern const char* const DEVICE_XLA_GPU;
-constexpr std::array<DataType, 3> kFloatTypes = {
- {DT_HALF, DT_FLOAT, DT_DOUBLE}};
-constexpr std::array<DataType, 8> kNumericTypes = {
+constexpr std::array<DataType, 4> kFloatTypes = {
+ {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
+constexpr std::array<DataType, 9> kNumericTypes = {
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64}};
+ DT_COMPLEX64, DT_BFLOAT16}};
constexpr std::array<DataType, 8> kCpuAllTypes = {
{DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc
index 58e1a73be6..094ba8bb86 100644
--- a/tensorflow/core/kernels/split_op.cc
+++ b/tensorflow/core/kernels/split_op.cc
@@ -360,6 +360,8 @@ class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {
TF_CALL_ALL_TYPES(REGISTER_SPLIT);
REGISTER_SPLIT(quint8);
+// TODO(xpan): Merge bfloat16 into TF_CALL_ALL_TYPES
+REGISTER_SPLIT(bfloat16);
#undef REGISTER_SPLIT
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index ae3b6c584a..509c5ec8d6 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -50,6 +50,7 @@ from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import device as pydev
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
@@ -1108,6 +1109,7 @@ class TensorFlowTestCase(googletest.TestCase):
"""
a = self._GetNdArray(a)
b = self._GetNdArray(b)
+    # Types with a lower tolerance are checked later so that they override
+    # the tolerances set by earlier, higher-tolerance types.
if (a.dtype == np.float32 or b.dtype == np.float32 or
a.dtype == np.complex64 or b.dtype == np.complex64):
rtol = max(rtol, float_rtol)
@@ -1115,6 +1117,10 @@ class TensorFlowTestCase(googletest.TestCase):
if a.dtype == np.float16 or b.dtype == np.float16:
rtol = max(rtol, half_rtol)
atol = max(atol, half_atol)
+ if (a.dtype == dtypes.bfloat16.as_numpy_dtype or
+ b.dtype == dtypes.bfloat16.as_numpy_dtype):
+ rtol = max(rtol, half_rtol)
+ atol = max(atol, half_atol)
self.assertAllClose(a, b, rtol=rtol, atol=atol)