aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/xla_compiler.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_compiler.h')
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h5
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 80593eaca5..acc64d99d3 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -20,6 +20,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -242,7 +244,8 @@ class XlaCompiler {
std::shared_ptr<xla::XlaComputation> computation;
};
- typedef std::function<TensorShape(const TensorShape&, DataType)>
+ typedef std::function<xla::StatusOr<TensorShape>(const TensorShape&,
+ DataType)>
ShapeRepresentationFn;
struct Options {
// Name of the compilation device to use. It must be set by the caller.