aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/lib
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2018-06-18 09:16:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-18 09:23:18 -0700
commit8ecf506fb8464dd273ce59f512f5e20d37dd5cfd (patch)
tree86463ca85a8382743d040000e7135fa135c16f32 /tensorflow/python/lib
parent3db3e50bb0c02d6f0c7284d50bc31e97ebfc96e5 (diff)
[TF:XLA] Add a XlaSort operator that directly wraps the Sort HLO.
Merge XLA-specific operator registrations into a single file rather than having many tiny files. In passing, register a fill function for bfloat16 numpy type; needed for the np.arange() call in the sort unit test. PiperOrigin-RevId: 201005718
Diffstat (limited to 'tensorflow/python/lib')
-rw-r--r--tensorflow/python/lib/core/bfloat16.cc11
-rw-r--r--tensorflow/python/lib/core/bfloat16_test.py14
2 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc
index 77fa2c1f66..fde3a83770 100644
--- a/tensorflow/python/lib/core/bfloat16.cc
+++ b/tensorflow/python/lib/core/bfloat16.cc
@@ -446,6 +446,16 @@ npy_bool NPyBfloat16_NonZero(void* data, void* arr) {
return x != static_cast<bfloat16>(0);
}
+int NPyBfloat16_Fill(void* buffer_raw, npy_intp length, void* ignored) {
+ bfloat16* const buffer = reinterpret_cast<bfloat16*>(buffer_raw);
+ const float start(buffer[0]);
+ const float delta = static_cast<float>(buffer[1]) - start;
+ for (npy_intp i = 2; i < length; ++i) {
+ buffer[i] = static_cast<bfloat16>(start + i * delta);
+ }
+ return 0;
+}
+
// NumPy casts
// Performs a NumPy array cast from type 'From' to 'To'.
@@ -548,6 +558,7 @@ bool Initialize() {
NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN;
NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap;
NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero;
+ NPyBfloat16_ArrFuncs.fill = NPyBfloat16_Fill;
Py_TYPE(&NPyBfloat16_Descr) = &PyArrayDescr_Type;
npy_bfloat16_ = PyArray_RegisterDataType(&NPyBfloat16_Descr);
diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py
index 09d4b01fa4..bc928cd9e5 100644
--- a/tensorflow/python/lib/core/bfloat16_test.py
+++ b/tensorflow/python/lib/core/bfloat16_test.py
@@ -245,6 +245,20 @@ class Bfloat16NumPyTest(test.TestCase):
np.logaddexp(x.astype(bfloat16), y.astype(bfloat16)),
atol=2e-2)
+ def testArange(self):
+ self.assertAllEqual(
+ np.arange(100, dtype=np.float32).astype(bfloat16),
+ np.arange(100, dtype=bfloat16))
+ self.assertAllEqual(
+ np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16),
+ np.arange(-10.5, 7.8, 0.5, dtype=bfloat16))
+ self.assertAllEqual(
+ np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16),
+ np.arange(-0., -7., -0.25, dtype=bfloat16))
+ self.assertAllEqual(
+ np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16),
+ np.arange(-16384., 16384., 64., dtype=bfloat16))
+
if __name__ == "__main__":
test.main()