diff options
author | Peter Hawkins <phawkins@google.com> | 2018-06-18 09:16:09 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-18 09:23:18 -0700 |
commit | 8ecf506fb8464dd273ce59f512f5e20d37dd5cfd (patch) | |
tree | 86463ca85a8382743d040000e7135fa135c16f32 /tensorflow/python/lib | |
parent | 3db3e50bb0c02d6f0c7284d50bc31e97ebfc96e5 (diff) |
[TF:XLA] Add a XlaSort operator that directly wraps the Sort HLO.
Merge XLA-specific operator registrations into a single file rather than having many tiny files.
In passing, register a fill function for bfloat16 numpy type; needed for the np.arange() call in the sort unit test.
PiperOrigin-RevId: 201005718
Diffstat (limited to 'tensorflow/python/lib')
-rw-r--r-- | tensorflow/python/lib/core/bfloat16.cc | 11 | ||||
-rw-r--r-- | tensorflow/python/lib/core/bfloat16_test.py | 14 |
2 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc index 77fa2c1f66..fde3a83770 100644 --- a/tensorflow/python/lib/core/bfloat16.cc +++ b/tensorflow/python/lib/core/bfloat16.cc @@ -446,6 +446,16 @@ npy_bool NPyBfloat16_NonZero(void* data, void* arr) { return x != static_cast<bfloat16>(0); } +int NPyBfloat16_Fill(void* buffer_raw, npy_intp length, void* ignored) { + bfloat16* const buffer = reinterpret_cast<bfloat16*>(buffer_raw); + const float start(buffer[0]); + const float delta = static_cast<float>(buffer[1]) - start; + for (npy_intp i = 2; i < length; ++i) { + buffer[i] = static_cast<bfloat16>(start + i * delta); + } + return 0; +} + // NumPy casts // Performs a NumPy array cast from type 'From' to 'To'. @@ -548,6 +558,7 @@ bool Initialize() { NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN; NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap; NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero; + NPyBfloat16_ArrFuncs.fill = NPyBfloat16_Fill; Py_TYPE(&NPyBfloat16_Descr) = &PyArrayDescr_Type; npy_bfloat16_ = PyArray_RegisterDataType(&NPyBfloat16_Descr); diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py index 09d4b01fa4..bc928cd9e5 100644 --- a/tensorflow/python/lib/core/bfloat16_test.py +++ b/tensorflow/python/lib/core/bfloat16_test.py @@ -245,6 +245,20 @@ class Bfloat16NumPyTest(test.TestCase): np.logaddexp(x.astype(bfloat16), y.astype(bfloat16)), atol=2e-2) + def testArange(self): + self.assertAllEqual( + np.arange(100, dtype=np.float32).astype(bfloat16), + np.arange(100, dtype=bfloat16)) + self.assertAllEqual( + np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16), + np.arange(-10.5, 7.8, 0.5, dtype=bfloat16)) + self.assertAllEqual( + np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16), + np.arange(-0., -7., -0.25, dtype=bfloat16)) + self.assertAllEqual( + np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16), + np.arange(-16384., 16384., 64., dtype=bfloat16)) + if __name__ == "__main__": test.main() |