diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_compiler.h')
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.h | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 80593eaca5..acc64d99d3 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" @@ -242,7 +244,8 @@ class XlaCompiler { std::shared_ptr<xla::XlaComputation> computation; }; - typedef std::function<TensorShape(const TensorShape&, DataType)> + typedef std::function<xla::StatusOr<TensorShape>(const TensorShape&, + DataType)> ShapeRepresentationFn; struct Options { // Name of the compilation device to use. It must be set by the caller. |