aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-10 13:38:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-10 13:42:10 -0700
commitc4ac9855539dc228881707c69a9ef2fe703dadd4 (patch)
tree5a9152abc7ce0a1353ec370988a91b55ae6a5f3b
parenta67c15a3291b96b54c6457bde832e1221a96381a (diff)
[XLA] tweak LocalComputation.Compile in XLA Python client so that layout_fn is optional
PiperOrigin-RevId: 204003372
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 27aee634ba..e2b6eaa096 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -461,14 +461,16 @@ class LocalComputation(object):
if self.is_compiled:
raise ValueError('Attempt to compile a compiled local XLA computation.')
+ result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape())
+
if layout_fn:
argument_shapes = [
shape.map_leaves(layout_fn) for shape in argument_shapes
]
- result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape())
result_shape = result_shape.map_leaves(layout_fn)
- compile_options = compile_options or CompileOptions()
- compile_options.result_shape = result_shape
+
+ compile_options = compile_options or CompileOptions()
+ compile_options.result_shape = result_shape
return LocalComputation(
self.c_local_computation.Compile(argument_shapes, compile_options),
is_compiled=True)