From c4ac9855539dc228881707c69a9ef2fe703dadd4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 10 Jul 2018 13:38:20 -0700 Subject: [XLA] tweak LocalComputation.Compile in XLA Python client so that layout_fn is optional PiperOrigin-RevId: 204003372 --- tensorflow/compiler/xla/python/xla_client.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 27aee634ba..e2b6eaa096 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -461,14 +461,16 @@ class LocalComputation(object): if self.is_compiled: raise ValueError('Attempt to compile a compiled local XLA computation.') + result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape()) + if layout_fn: argument_shapes = [ shape.map_leaves(layout_fn) for shape in argument_shapes ] - result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape()) result_shape = result_shape.map_leaves(layout_fn) - compile_options = compile_options or CompileOptions() - compile_options.result_shape = result_shape + + compile_options = compile_options or CompileOptions() + compile_options.result_shape = result_shape return LocalComputation( self.c_local_computation.Compile(argument_shapes, compile_options), is_compiled=True) -- cgit v1.2.3