aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Roy Frostig <frostig@google.com>2018-04-19 23:01:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-19 23:03:38 -0700
commit70b8d21edcc84818835c9e2940a5df288c309d45 (patch)
tree89c55762017aa02afd09b3c80c7553f678c5e06e
parent2273c4e56334caf31de01c6b6f8f4edd48432972 (diff)
[XLA] Rework the local XLA client's Shape class with separate array and tuple shape constructors.
PiperOrigin-RevId: 193624591
-rw-r--r--tensorflow/compiler/xla/python/numpy_bridge.cc20
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py137
-rw-r--r--tensorflow/compiler/xla/python/xla_client_test.py10
3 files changed, 103 insertions, 64 deletions
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc
index eec48479c9..dc6f5fe5fc 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.cc
+++ b/tensorflow/compiler/xla/python/numpy_bridge.cc
@@ -181,16 +181,6 @@ StatusOr<Shape> XlaShapeFromPyShape(PyObject* o) {
PyObjectCppRepr(o).c_str());
};
- auto get_attr = [o, &error](const string& field) -> StatusOr<PyObject*> {
- PyObject* result =
- PyObject_GetAttrString(o, const_cast<char*>(field.c_str()));
- if (result == nullptr) {
- return error(tensorflow::strings::StrCat(
- "Failed to get attribute of Shape object:", field));
- }
- return result;
- };
-
auto call_method = [o, &error](const string& method) -> StatusOr<PyObject*> {
PyObject* result =
PyObject_CallMethod(o, const_cast<char*>(method.c_str()), nullptr);
@@ -202,12 +192,16 @@ StatusOr<Shape> XlaShapeFromPyShape(PyObject* o) {
};
PyObject* np_type;
- TF_ASSIGN_OR_RETURN(np_type, get_attr("np_dtype"));
+ TF_ASSIGN_OR_RETURN(np_type, call_method("numpy_dtype"));
if (np_type->ob_type != &PyArrayDescr_Type) {
- return error("Shape attribute np_dtype is not an integer numpy dtype");
+ return error(
+ "Return value of shape method numpy_dtype "
+ "is not an integer numpy dtype");
}
if (!NumpyTypeIsValid(NumpyTypenum(np_type))) {
- return error("Shape attribute np_dtype is not a valid integer numpy dtype");
+ return error(
+ "Return value of shape method numpy_dtype "
+ "is not a valid integer numpy dtype");
}
const PrimitiveType element_type =
NumpyTypeToPrimitiveType(NumpyTypenum(np_type));
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 9c81f6439d..f6809b6b87 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -166,14 +166,14 @@ class LocalBuffer(object):
self._delete = c_api.DeleteLocalShapedBuffer
@staticmethod
- def from_py(npval, layout_fn=None):
- npval = require_numpy_array_layout(npval)
+ def from_pyval(pyval, layout_fn=None):
+ pyval = require_numpy_array_layout(pyval)
if layout_fn:
- shape = Shape.from_numpy(npval)
+ shape = Shape.from_pyval(pyval)
shape = shape.map_leaves(layout_fn)
else:
shape = None
- return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval, shape))
+ return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(pyval, shape))
def to_py(self):
return self.c_local_shaped_buffer.ToLiteral()
@@ -191,53 +191,104 @@ class LocalBuffer(object):
class Shape(object):
- """XLA shape.
+ """Represents an XLA shape.
- Represents an XLA shape by a corresponding Python/Numpy type and a
- list of dimensions, which are themselves Shapes in case this one
- represents an XLA tuple.
+ A shape is either an array shape, having rank-many integer
+ dimensions and an element type (represented by a Numpy dtype), or it
+ is a tuple shape, having a shape for every tuple component:
+
+ type shape =
+ TupleShape of shape list
+ | ArrayShape of { dimensions: int list; element_type: dtype }
+
+ Callers are expected to instantiate this class only via the static
+ constructors: tuple_shape, array_shape, and from_pyval.
"""
- def __init__(self, np_dtype, dimensions, minor_to_major=None):
+ @staticmethod
+ def tuple_shape(tuple_shapes):
+ """Construct a tuple shape."""
+ if (not isinstance(tuple_shapes, (tuple, list)) or
+ not all(isinstance(t, Shape) for t in tuple_shapes)):
+ raise TypeError('tuple_shapes must be a tuple of Shapes')
+ return Shape(tuple_shapes, tuple)
+
+ @staticmethod
+ def array_shape(element_type, dimensions, minor_to_major=None):
+ """Construct an array shape."""
+ if (not isinstance(dimensions, tuple) or
+ not all(isinstance(i, int) for i in dimensions)):
+ dimensions = tuple(int(i) for i in dimensions)
+ return Shape(dimensions, np.dtype(element_type),
+ minor_to_major=minor_to_major)
+
+ @staticmethod
+ def from_pyval(pyval):
+ def convert(pyval):
+ if isinstance(pyval, tuple):
+ return Shape.tuple_shape(tuple(convert(elt) for elt in pyval))
+ else:
+ pyval = require_numpy_array_layout(pyval)
+ return Shape.array_shape(pyval.dtype, np.shape(pyval))
+ return convert(pyval)
+
+ def __init__(self, dimensions, dtype, minor_to_major=None):
assert isinstance(dimensions, tuple)
- self.np_dtype = np_dtype
self._dimensions = dimensions
+ self._dtype = dtype
+ self._is_tuple = dtype == tuple
self._minor_to_major = minor_to_major
self._check_minor_to_major()
def __eq__(self, other):
# pylint: disable=protected-access
- return (self.np_dtype == other.np_dtype and
+ return (self._dtype == other._dtype and
self._dimensions == other._dimensions and
self._minor_to_major == other._minor_to_major)
def __repr__(self):
- return ('xla_client.Shape(np_dtype={!r}, dimensions={!r}, '
- 'minor_to_major={!r})').format(self.np_dtype, self._dimensions,
- self._minor_to_major)
-
- def element_type(self):
- return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.np_dtype)]
+ return ('xla_client.Shape(_dtype={!r}, _dimensions={!r}, '
+ '_is_tuple={!r}), _minor_to_major={!r}').format(
+ self._dtype, self._dimensions, self._is_tuple,
+ self._minor_to_major)
def is_tuple(self):
- return self.element_type() == xla_data_pb2.TUPLE
+ return self._is_tuple
- def dimensions(self):
- if self.is_tuple():
- raise ValueError('Tuple shape has no dimensions')
- return self._dimensions
-
- def minor_to_major(self):
- return self._minor_to_major
+ def is_array(self):
+ return not self._is_tuple
def tuple_shapes(self):
if not self.is_tuple():
- raise ValueError('Shape is not a tuple shape')
+ raise ValueError('not a tuple shape')
+ return self._dimensions
+
+ def numpy_dtype(self):
+ """Like element_type(), but returns dtype('O') in case of a tuple shape."""
+ if self.is_tuple():
+ return np.dtype(np.object)
+ else:
+ return self.element_type()
+
+ def xla_element_type(self):
+ return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.numpy_dtype())]
+
+ def element_type(self):
+ if not self.is_array():
+ raise ValueError('not an array shape')
+ return self._dtype
+
+ def dimensions(self):
+ if not self.is_array():
+ raise ValueError('not an array shape')
return self._dimensions
def rank(self):
return len(self.dimensions())
+ def minor_to_major(self):
+ return self._minor_to_major
+
def map_leaves(self, f):
"""Map f over each leaf-level array subshape.
@@ -250,7 +301,7 @@ class Shape(object):
"""
if self.is_tuple():
children = tuple(child.map_leaves(f) for child in self.tuple_shapes())
- return Shape(np.dtype('O'), children)
+ return Shape.tuple_shape(children)
else:
mapped = f(self)
return self if mapped is None else mapped
@@ -264,30 +315,24 @@ class Shape(object):
assert sorted(mtm) == range(len(mtm)), self
def update_minor_to_major(self, minor_to_major):
+ if not self.is_array():
+ raise ValueError('not an array shape')
if not isinstance(minor_to_major, tuple):
raise TypeError('minor_to_major must be a tuple')
- updated = Shape(self.np_dtype, tuple(self.dimensions()), minor_to_major)
+ updated = Shape.array_shape(
+ self.element_type(), self.dimensions(), minor_to_major)
updated._check_minor_to_major() # pylint: disable=protected-access
return updated
- @staticmethod
- def from_numpy(npval):
-
- def convert(npval):
- if isinstance(npval, tuple):
- return Shape(np.dtype('O'), tuple(convert(elt) for elt in npval))
- else:
- return Shape(npval.dtype, np.shape(npval))
-
- return convert(require_numpy_array_layout(npval))
-
def _wrap_shape(shape_info):
dtype, dims = shape_info
element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)]
if element_type == xla_data_pb2.TUPLE:
- dims = tuple(_wrap_shape(subshape_info) for subshape_info in dims)
- return Shape(dtype, dims)
+ shapes = tuple(_wrap_shape(subshape_info) for subshape_info in dims)
+ return Shape.tuple_shape(shapes)
+ else:
+ return Shape.array_shape(dtype, dims)
def _wrap_data_handle(handle):
@@ -420,7 +465,7 @@ class LocalComputation(object):
compile_options=None,
layout_fn=None):
return self.Compile(
- argument_shapes=[Shape.from_numpy(arg) for arg in arguments],
+ argument_shapes=[Shape.from_pyval(arg) for arg in arguments],
compile_options=compile_options,
layout_fn=layout_fn)
@@ -428,7 +473,7 @@ class LocalComputation(object):
"""Execute with Python values as arguments and return value."""
if not self.is_compiled:
raise ValueError('Cannot execute an uncompiled local XLA computation.')
- argument_shapes = [Shape.from_numpy(arg) for arg in arguments]
+ argument_shapes = [Shape.from_pyval(arg) for arg in arguments]
if layout_fn:
argument_shapes = [
shape.map_leaves(layout_fn) for shape in argument_shapes
@@ -607,7 +652,7 @@ class ComputationBuilder(object):
A ComputationDataHandle message.
"""
return self.ParameterWithShape(
- Shape.from_numpy(value), name=name, parameter_num=parameter_num)
+ Shape.from_pyval(value), name=name, parameter_num=parameter_num)
def Broadcast(self, operand, sizes):
"""Enqueues a broadcast operation onto the computation.
@@ -968,7 +1013,7 @@ class ComputationBuilder(object):
Returns: a ComputationDataHandle to the generated array of F32 values.
"""
- shape = Shape(self.GetShape(mu).np_dtype, dims)
+ shape = Shape.array_shape(self.GetShape(mu).element_type(), dims)
return _wrap_data_handle(
self._client.RngNormal(
_unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape))
@@ -988,7 +1033,7 @@ class ComputationBuilder(object):
Returns: a ComputationDataHandle to the generated array of values with the
same numeric type (F32, S32, or U32) as the arguments a and b.
"""
- shape = Shape(self.GetShape(a).np_dtype, dims)
+ shape = Shape.array_shape(self.GetShape(a).element_type(), dims)
return _wrap_data_handle(
self._client.RngUniform(
_unwrap_data_handle(a), _unwrap_data_handle(b), shape))
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index d97264ea64..6fe7b242e4 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -319,7 +319,7 @@ class LocalBufferTest(LocalComputationTest):
def _Execute(self, c, arguments):
compiled_c = c.Build().CompileWithExampleArguments(arguments)
- arg_buffers = [xla_client.LocalBuffer.from_py(arg) for arg in arguments]
+ arg_buffers = [xla_client.LocalBuffer.from_pyval(arg) for arg in arguments]
result_buffer = compiled_c.ExecuteWithLocalBuffers(arg_buffers)
return result_buffer.to_py()
@@ -350,7 +350,7 @@ class LocalBufferTest(LocalComputationTest):
c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14))
arg = NumpyArrayF32(1.11)
compiled_c = c.Build().CompileWithExampleArguments([arg])
- arg_buffer = xla_client.LocalBuffer.from_py(arg)
+ arg_buffer = xla_client.LocalBuffer.from_pyval(arg)
arg_buffer.delete()
with self.assertRaises(ValueError):
compiled_c.ExecuteWithLocalBuffers([arg_buffer])
@@ -1288,7 +1288,7 @@ class EmbeddedComputationsTest(LocalComputationTest):
def testInfeedS32Values(self):
to_infeed = NumpyArrayS32([1, 2, 3, 4])
c = self._NewComputation()
- c.Infeed(xla_client.Shape.from_numpy(to_infeed[0]))
+ c.Infeed(xla_client.Shape.from_pyval(to_infeed[0]))
compiled_c = c.Build().CompileWithExampleArguments()
for item in to_infeed:
xla_client.transfer_to_infeed(item)
@@ -1300,7 +1300,7 @@ class EmbeddedComputationsTest(LocalComputationTest):
def testInfeedThenOutfeedS32(self):
to_round_trip = NumpyArrayS32([1, 2, 3, 4])
c = self._NewComputation()
- x = c.Infeed(xla_client.Shape.from_numpy(to_round_trip[0]))
+ x = c.Infeed(xla_client.Shape.from_pyval(to_round_trip[0]))
c.Outfeed(x)
compiled_c = c.Build().CompileWithExampleArguments()
@@ -1310,7 +1310,7 @@ class EmbeddedComputationsTest(LocalComputationTest):
execution.start()
xla_client.transfer_to_infeed(want)
got = xla_client.transfer_from_outfeed(
- xla_client.Shape.from_numpy(to_round_trip[0]))
+ xla_client.Shape.from_pyval(to_round_trip[0]))
execution.join()
self.assertEqual(want, got)