diff options
author | 2018-07-10 13:38:20 -0700 | |
---|---|---|
committer | 2018-07-10 13:42:10 -0700 | |
commit | c4ac9855539dc228881707c69a9ef2fe703dadd4 (patch) | |
tree | 5a9152abc7ce0a1353ec370988a91b55ae6a5f3b | |
parent | a67c15a3291b96b54c6457bde832e1221a96381a (diff) |
[XLA] tweak LocalComputation.Compile in XLA Python client so that layout_fn is optional
PiperOrigin-RevId: 204003372
-rw-r--r-- | tensorflow/compiler/xla/python/xla_client.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 27aee634ba..e2b6eaa096 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -461,14 +461,16 @@ class LocalComputation(object): if self.is_compiled: raise ValueError('Attempt to compile a compiled local XLA computation.') + result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape()) + if layout_fn: argument_shapes = [ shape.map_leaves(layout_fn) for shape in argument_shapes ] - result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape()) result_shape = result_shape.map_leaves(layout_fn) - compile_options = compile_options or CompileOptions() - compile_options.result_shape = result_shape + + compile_options = compile_options or CompileOptions() + compile_options.result_shape = result_shape return LocalComputation( self.c_local_computation.Compile(argument_shapes, compile_options), is_compiled=True) |