diff options
author | 2018-04-19 23:01:07 -0700 | |
---|---|---|
committer | 2018-04-19 23:03:38 -0700 | |
commit | 70b8d21edcc84818835c9e2940a5df288c309d45 (patch) | |
tree | 89c55762017aa02afd09b3c80c7553f678c5e06e | |
parent | 2273c4e56334caf31de01c6b6f8f4edd48432972 (diff) |
[XLA] Rework the local XLA client's Shape class with separate array and tuple shape constructors.
PiperOrigin-RevId: 193624591
-rw-r--r-- | tensorflow/compiler/xla/python/numpy_bridge.cc | 20 | ||||
-rw-r--r-- | tensorflow/compiler/xla/python/xla_client.py | 137 | ||||
-rw-r--r-- | tensorflow/compiler/xla/python/xla_client_test.py | 10 |
3 files changed, 103 insertions, 64 deletions
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index eec48479c9..dc6f5fe5fc 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -181,16 +181,6 @@ StatusOr<Shape> XlaShapeFromPyShape(PyObject* o) { PyObjectCppRepr(o).c_str()); }; - auto get_attr = [o, &error](const string& field) -> StatusOr<PyObject*> { - PyObject* result = - PyObject_GetAttrString(o, const_cast<char*>(field.c_str())); - if (result == nullptr) { - return error(tensorflow::strings::StrCat( - "Failed to get attribute of Shape object:", field)); - } - return result; - }; - auto call_method = [o, &error](const string& method) -> StatusOr<PyObject*> { PyObject* result = PyObject_CallMethod(o, const_cast<char*>(method.c_str()), nullptr); @@ -202,12 +192,16 @@ StatusOr<Shape> XlaShapeFromPyShape(PyObject* o) { }; PyObject* np_type; - TF_ASSIGN_OR_RETURN(np_type, get_attr("np_dtype")); + TF_ASSIGN_OR_RETURN(np_type, call_method("numpy_dtype")); if (np_type->ob_type != &PyArrayDescr_Type) { - return error("Shape attribute np_dtype is not an integer numpy dtype"); + return error( + "Return value of shape method numpy_dtype " + "is not an integer numpy dtype"); } if (!NumpyTypeIsValid(NumpyTypenum(np_type))) { - return error("Shape attribute np_dtype is not a valid integer numpy dtype"); + return error( + "Return value of shape method numpy_dtype " + "is not a valid integer numpy dtype"); } const PrimitiveType element_type = NumpyTypeToPrimitiveType(NumpyTypenum(np_type)); diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 9c81f6439d..f6809b6b87 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -166,14 +166,14 @@ class LocalBuffer(object): self._delete = c_api.DeleteLocalShapedBuffer @staticmethod - def from_py(npval, layout_fn=None): - npval = require_numpy_array_layout(npval) + def from_pyval(pyval, layout_fn=None): + pyval = require_numpy_array_layout(pyval) if layout_fn: - shape = Shape.from_numpy(npval) + shape = Shape.from_pyval(pyval) shape = shape.map_leaves(layout_fn) else: shape = None - return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval, shape)) + return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(pyval, shape)) def to_py(self): return self.c_local_shaped_buffer.ToLiteral() @@ -191,53 +191,104 @@ class LocalBuffer(object): class Shape(object): - """XLA shape. + """Represents an XLA shape. - Represents an XLA shape by a corresponding Python/Numpy type and a - list of dimensions, which are themselves Shapes in case this one - represents an XLA tuple. + A shape is either an array shape, having rank-many integer + dimensions and an element type (represented by a Numpy dtype), or it + is a tuple shape, having a shape for every tuple component: + + type shape = + TupleShape of shape list + | ArrayShape of { dimensions: int list; element_type: dtype } + + Callers are expected to instantiate this class only via the static + constructors: tuple_shape, array_shape, and from_pyval. """ - def __init__(self, np_dtype, dimensions, minor_to_major=None): + @staticmethod + def tuple_shape(tuple_shapes): + """Construct a tuple shape.""" + if (not isinstance(tuple_shapes, (tuple, list)) or + not all(isinstance(t, Shape) for t in tuple_shapes)): + raise TypeError('tuple_shapes must be a tuple of Shapes') + return Shape(tuple_shapes, tuple) + + @staticmethod + def array_shape(element_type, dimensions, minor_to_major=None): + """Construct an array shape.""" + if (not isinstance(dimensions, tuple) or + not all(isinstance(i, int) for i in dimensions)): + dimensions = tuple(int(i) for i in dimensions) + return Shape(dimensions, np.dtype(element_type), + minor_to_major=minor_to_major) + + @staticmethod + def from_pyval(pyval): + def convert(pyval): + if isinstance(pyval, tuple): + return Shape.tuple_shape(tuple(convert(elt) for elt in pyval)) + else: + pyval = require_numpy_array_layout(pyval) + return Shape.array_shape(pyval.dtype, np.shape(pyval)) + return convert(pyval) + + def __init__(self, dimensions, dtype, minor_to_major=None): assert isinstance(dimensions, tuple) - self.np_dtype = np_dtype self._dimensions = dimensions + self._dtype = dtype + self._is_tuple = dtype == tuple self._minor_to_major = minor_to_major self._check_minor_to_major() def __eq__(self, other): # pylint: disable=protected-access - return (self.np_dtype == other.np_dtype and + return (self._dtype == other._dtype and self._dimensions == other._dimensions and self._minor_to_major == other._minor_to_major) def __repr__(self): - return ('xla_client.Shape(np_dtype={!r}, dimensions={!r}, ' - 'minor_to_major={!r})').format(self.np_dtype, self._dimensions, - self._minor_to_major) - - def element_type(self): - return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.np_dtype)] + return ('xla_client.Shape(_dtype={!r}, _dimensions={!r}, ' + '_is_tuple={!r}), _minor_to_major={!r}').format( + self._dtype, self._dimensions, self._is_tuple, + self._minor_to_major) def is_tuple(self): - return self.element_type() == xla_data_pb2.TUPLE + return self._is_tuple - def dimensions(self): - if self.is_tuple(): - raise ValueError('Tuple shape has no dimensions') - return self._dimensions - - def minor_to_major(self): - return self._minor_to_major + def is_array(self): + return not self._is_tuple def tuple_shapes(self): if not self.is_tuple(): - raise ValueError('Shape is not a tuple shape') + raise ValueError('not a tuple shape') + return self._dimensions + + def numpy_dtype(self): + """Like element_type(), but returns dtype('O') in case of a tuple shape.""" + if self.is_tuple(): + return np.dtype(np.object) + else: + return self.element_type() + + def xla_element_type(self): + return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.numpy_dtype())] + + def element_type(self): + if not self.is_array(): + raise ValueError('not an array shape') + return self._dtype + + def dimensions(self): + if not self.is_array(): + raise ValueError('not an array shape') return self._dimensions def rank(self): return len(self.dimensions()) + def minor_to_major(self): + return self._minor_to_major + def map_leaves(self, f): """Map f over each leaf-level array subshape. @@ -250,7 +301,7 @@ class Shape(object): """ if self.is_tuple(): children = tuple(child.map_leaves(f) for child in self.tuple_shapes()) - return Shape(np.dtype('O'), children) + return Shape.tuple_shape(children) else: mapped = f(self) return self if mapped is None else mapped @@ -264,30 +315,24 @@ class Shape(object): assert sorted(mtm) == range(len(mtm)), self def update_minor_to_major(self, minor_to_major): + if not self.is_array(): + raise ValueError('not an array shape') if not isinstance(minor_to_major, tuple): raise TypeError('minor_to_major must be a tuple') - updated = Shape(self.np_dtype, tuple(self.dimensions()), minor_to_major) + updated = Shape.array_shape( + self.element_type(), self.dimensions(), minor_to_major) updated._check_minor_to_major() # pylint: disable=protected-access return updated - @staticmethod - def from_numpy(npval): - - def convert(npval): - if isinstance(npval, tuple): - return Shape(np.dtype('O'), tuple(convert(elt) for elt in npval)) - else: - return Shape(npval.dtype, np.shape(npval)) - - return convert(require_numpy_array_layout(npval)) - def _wrap_shape(shape_info): dtype, dims = shape_info element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)] if element_type == xla_data_pb2.TUPLE: - dims = tuple(_wrap_shape(subshape_info) for subshape_info in dims) - return Shape(dtype, dims) + shapes = tuple(_wrap_shape(subshape_info) for subshape_info in dims) + return Shape.tuple_shape(shapes) + else: + return Shape.array_shape(dtype, dims) def _wrap_data_handle(handle): @@ -420,7 +465,7 @@ class LocalComputation(object): compile_options=None, layout_fn=None): return self.Compile( - argument_shapes=[Shape.from_numpy(arg) for arg in arguments], + argument_shapes=[Shape.from_pyval(arg) for arg in arguments], compile_options=compile_options, layout_fn=layout_fn) @@ -428,7 +473,7 @@ class LocalComputation(object): """Execute with Python values as arguments and return value.""" if not self.is_compiled: raise ValueError('Cannot execute an uncompiled local XLA computation.') - argument_shapes = [Shape.from_numpy(arg) for arg in arguments] + argument_shapes = [Shape.from_pyval(arg) for arg in arguments] if layout_fn: argument_shapes = [ shape.map_leaves(layout_fn) for shape in argument_shapes @@ -607,7 +652,7 @@ class ComputationBuilder(object): A ComputationDataHandle message. """ return self.ParameterWithShape( - Shape.from_numpy(value), name=name, parameter_num=parameter_num) + Shape.from_pyval(value), name=name, parameter_num=parameter_num) def Broadcast(self, operand, sizes): """Enqueues a broadcast operation onto the computation. @@ -968,7 +1013,7 @@ class ComputationBuilder(object): Returns: a ComputationDataHandle to the generated array of F32 values. """ - shape = Shape(self.GetShape(mu).np_dtype, dims) + shape = Shape.array_shape(self.GetShape(mu).element_type(), dims) return _wrap_data_handle( self._client.RngNormal( _unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape)) @@ -988,7 +1033,7 @@ class ComputationBuilder(object): Returns: a ComputationDataHandle to the generated array of values with the same numeric type (F32, S32, or U32) as the arguments a and b. """ - shape = Shape(self.GetShape(a).np_dtype, dims) + shape = Shape.array_shape(self.GetShape(a).element_type(), dims) return _wrap_data_handle( self._client.RngUniform( _unwrap_data_handle(a), _unwrap_data_handle(b), shape)) diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index d97264ea64..6fe7b242e4 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -319,7 +319,7 @@ class LocalBufferTest(LocalComputationTest): def _Execute(self, c, arguments): compiled_c = c.Build().CompileWithExampleArguments(arguments) - arg_buffers = [xla_client.LocalBuffer.from_py(arg) for arg in arguments] + arg_buffers = [xla_client.LocalBuffer.from_pyval(arg) for arg in arguments] result_buffer = compiled_c.ExecuteWithLocalBuffers(arg_buffers) return result_buffer.to_py() @@ -350,7 +350,7 @@ class LocalBufferTest(LocalComputationTest): c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14)) arg = NumpyArrayF32(1.11) compiled_c = c.Build().CompileWithExampleArguments([arg]) - arg_buffer = xla_client.LocalBuffer.from_py(arg) + arg_buffer = xla_client.LocalBuffer.from_pyval(arg) arg_buffer.delete() with self.assertRaises(ValueError): compiled_c.ExecuteWithLocalBuffers([arg_buffer]) @@ -1288,7 +1288,7 @@ class EmbeddedComputationsTest(LocalComputationTest): def testInfeedS32Values(self): to_infeed = NumpyArrayS32([1, 2, 3, 4]) c = self._NewComputation() - c.Infeed(xla_client.Shape.from_numpy(to_infeed[0])) + c.Infeed(xla_client.Shape.from_pyval(to_infeed[0])) compiled_c = c.Build().CompileWithExampleArguments() for item in to_infeed: xla_client.transfer_to_infeed(item) @@ -1300,7 +1300,7 @@ class EmbeddedComputationsTest(LocalComputationTest): def testInfeedThenOutfeedS32(self): to_round_trip = NumpyArrayS32([1, 2, 3, 4]) c = self._NewComputation() - x = c.Infeed(xla_client.Shape.from_numpy(to_round_trip[0])) + x = c.Infeed(xla_client.Shape.from_pyval(to_round_trip[0])) c.Outfeed(x) compiled_c = c.Build().CompileWithExampleArguments() @@ -1310,7 +1310,7 @@ class EmbeddedComputationsTest(LocalComputationTest): execution.start() xla_client.transfer_to_infeed(want) got = xla_client.transfer_from_outfeed( - xla_client.Shape.from_numpy(to_round_trip[0])) + xla_client.Shape.from_pyval(to_round_trip[0])) execution.join() self.assertEqual(want, got) |