diff options
5 files changed, 117 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 18f6643121..37f1eada2b 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -298,6 +298,13 @@ ComputationDataHandle LocalComputationBuilder::Broadcast( return builder_.Broadcast(operand, broadcast_sizes); } +ComputationDataHandle LocalComputationBuilder::Pad( + const ComputationDataHandle& operand, + const ComputationDataHandle& padding_value, + const PaddingConfig& padding_config) { + return builder_.Pad(operand, padding_value, padding_config); +} + ComputationDataHandle LocalComputationBuilder::Reshape( const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice<int64> dimensions, diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 038ef1d6f0..e5503cd52f 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -133,6 +133,10 @@ class LocalComputationBuilder { const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes); + ComputationDataHandle Pad(const ComputationDataHandle& operand, + const ComputationDataHandle& padding_value, + const PaddingConfig& padding_config); + ComputationDataHandle Reshape(const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice<int64> dimensions, tensorflow::gtl::ArraySlice<int64> new_sizes); diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 3826172199..3178925960 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -31,6 +31,7 @@ limitations under the License. // std::vector<Shape> <- sequence of shape information pairs // PrimitiveType <- int // ArraySlice<pair<int64, in64>> <- sequence of int pairs +// PaddingConfig proto <- corresponding Python proto // ConvolutionDimensionNumbers proto <- corresponding Python proto // // Arrows indicate whether a conversion only ever occurs in one @@ -460,6 +461,48 @@ tensorflow::ImportNumpy(); $1 = temps; } +// PaddingConfig + +%typemap(in) const PaddingConfig& + (PaddingConfig padding_config) { + PyObject* dimensions = PyObject_GetAttrString($input, "dimensions"); + if (!dimensions) { + return NULL; + } + + int length = PySequence_Size(dimensions); + if (length == -1) { + Py_DECREF(dimensions); + return NULL; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(dimensions, i); + if (!item) { + Py_DECREF(dimensions); + return NULL; + } + int64 edge_padding_low, edge_padding_high, interior_padding; + if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low) + || !GetIntAttr(item, "edge_padding_high", &edge_padding_high) + || !GetIntAttr(item, "interior_padding", &interior_padding)) { + Py_DECREF(item); + Py_DECREF(dimensions); + return NULL; + } + Py_DECREF(item); + + PaddingConfig::PaddingConfigDimension* dimension = + padding_config.add_dimensions(); + dimension->set_edge_padding_low(edge_padding_low); + dimension->set_edge_padding_high(edge_padding_high); + dimension->set_interior_padding(interior_padding); + } + Py_DECREF(dimensions); + + $1 = &padding_config; +} + // ConvolutionDimensionNumbers %typemap(in) const ConvolutionDimensionNumbers& @@ -608,6 +651,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::ConstantLiteral; %unignore xla::swig::LocalComputationBuilder::ConstantR0; %unignore xla::swig::LocalComputationBuilder::Broadcast; +%unignore xla::swig::LocalComputationBuilder::Pad; %unignore xla::swig::LocalComputationBuilder::Reshape; %unignore xla::swig::LocalComputationBuilder::Collapse; %unignore xla::swig::LocalComputationBuilder::CrossReplicaSum; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 2727bbdf99..5455adafcd 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -541,6 +541,36 @@ class ComputationBuilder(object): def GetComputationStats(self): raise NotImplementedError() + def Pad(self, operand, padding_value, padding_config): + """Enqueues a Pad operation onto the computation. + + Args: + operand: ComputationDataHandle representing the array to pad. + padding_value: ComputationDataHandle representing the scalar pad value. + padding_config: either an xla_data_pb2.PaddingConfig or a list of integer + triples (edge_padding_low, edge_padding_high, interior_padding) + representing the configuration of the padding operation. + + Returns: + A ComputationDataHandle representing the added pad op. + """ + if not isinstance(padding_config, xla_data_pb2.PaddingConfig): + padding_config = self._GetPaddingConfigFromTriples(padding_config) + return _wrap_data_handle( + self._client.Pad(_unwrap_data_handle(operand), + _unwrap_data_handle(padding_value), + padding_config)) + + def _GetPaddingConfigFromTriples(self, triples): + """Create PaddingConfig proto from list of triples of integers.""" + padding_config = xla_data_pb2.PaddingConfig() + for lo, hi, interior in triples: + dimension = padding_config.dimensions.add() + dimension.edge_padding_low = lo + dimension.edge_padding_high = hi + dimension.interior_padding = interior + return padding_config + def Reshape(self, operand, dimensions, new_sizes): """Reshape op.""" return _wrap_data_handle( diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index d0d9c73cc9..c0413b9bbc 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -644,6 +644,38 @@ class SingleOpTest(LocalComputationTest): c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) self._ExecuteAndCompareExact(c, expected=[1.0, 0.0, 2.0, 4.0, 9.0]) + def testPad(self): + c = self._NewComputation() + c.Pad( + c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), + c.Constant(NumpyArrayF32(0.0)), + [(1, 2, 1), (0, 1, 0)]) + self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0], + [1.0, 2.0, 0.0], + [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0]]) + + def testPadWithPaddingConfig(self): + c = self._NewComputation() + padding_config = xla_client.xla_data_pb2.PaddingConfig() + for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]: + dimension = padding_config.dimensions.add() + dimension.edge_padding_low = lo + dimension.edge_padding_high = hi + dimension.interior_padding = interior + c.Pad( + c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), + c.Constant(NumpyArrayF32(0.0)), + padding_config) + self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0], + [1.0, 2.0, 0.0], + [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0]]) + def testReshape(self): c = self._NewComputation() c.Reshape( |