aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar DavidNorman <davidn@graphcore.ai>2017-05-26 07:59:40 +0100
committerGravatar Martin Wicke <martin.wicke@gmail.com>2017-06-13 18:41:07 -0700
commitd59b64fa4c5220eb50e32703b39832f3ed0b3726 (patch)
treed2704503b65c2cab09bf5f16e7011b53318dd614 /tensorflow
parent43d8475a551723d905cf7bd122825fa58556f06e (diff)
Adjust API to match current public repo
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/plugin/executor/compiler.cc10
-rw-r--r--tensorflow/compiler/plugin/executor/compiler.h2
-rw-r--r--tensorflow/compiler/plugin/executor/device.cc4
-rw-r--r--tensorflow/compiler/plugin/executor/executable.cc10
-rw-r--r--tensorflow/compiler/plugin/executor/executable.h2
-rw-r--r--tensorflow/compiler/plugin/executor/transfer_manager.cc4
6 files changed, 22 insertions, 10 deletions
diff --git a/tensorflow/compiler/plugin/executor/compiler.cc b/tensorflow/compiler/plugin/executor/compiler.cc
index a3591a8a47..893ff152f0 100644
--- a/tensorflow/compiler/plugin/executor/compiler.cc
+++ b/tensorflow/compiler/plugin/executor/compiler.cc
@@ -103,14 +103,16 @@ ExecutorCompiler::CompileAheadOfTime(
"AOT compilation not supported on Executor");
}
-int64 ExecutorCompiler::ShapeSizeBytes(const Shape& shape) const {
- return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
-}
-
se::Platform::Id ExecutorCompiler::PlatformId() const {
return sep::kExecutorPlatformId;
}
+HloCostAnalysis::ShapeSizeFunction
+ExecutorCompiler::ShapeSizeBytesFunction() const {
+ return ExecutorExecutable::ShapeSizeBytes;
+}
+
+
} // namespace executorplugin
} // namespace xla
diff --git a/tensorflow/compiler/plugin/executor/compiler.h b/tensorflow/compiler/plugin/executor/compiler.h
index e66f0a4ea3..8fe591c8ab 100644
--- a/tensorflow/compiler/plugin/executor/compiler.h
+++ b/tensorflow/compiler/plugin/executor/compiler.h
@@ -48,7 +48,7 @@ class ExecutorCompiler : public Compiler {
std::vector<std::unique_ptr<HloModule>> module,
HloDumper dump_hlo, const AotCompilationOptions& options) override;
- int64 ShapeSizeBytes(const Shape& shape) const override;
+ HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override;
perftools::gputools::Platform::Id PlatformId() const override;
diff --git a/tensorflow/compiler/plugin/executor/device.cc b/tensorflow/compiler/plugin/executor/device.cc
index 10e13226f9..54b49d789d 100644
--- a/tensorflow/compiler/plugin/executor/device.cc
+++ b/tensorflow/compiler/plugin/executor/device.cc
@@ -27,7 +27,7 @@ const char* const DEVICE_XLA_EXEC = "XLA_EXEC";
const char* const DEVICE_EXEC_XLA_JIT = "XLA_EXEC_JIT";
constexpr std::array<DataType, 5> kExecAllTypes = {
- {DT_INT32, DT_FLOAT, DT_BOOL}};
+ {DT_INT32, DT_FLOAT, DT_BOOL, DT_DOUBLE, DT_INT64}};
class XlaExaDeviceFactory : public DeviceFactory {
public:
@@ -50,7 +50,7 @@ Status XlaExaDeviceFactory::CreateDevices(const SessionOptions& options,
return Status::OK();
}
-REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_EXEC, XlaExaDeviceFactory, 210);
+REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_EXEC, XlaExaDeviceFactory, 110);
// Kernel registrations
diff --git a/tensorflow/compiler/plugin/executor/executable.cc b/tensorflow/compiler/plugin/executor/executable.cc
index 0caee906ac..480d9d98bb 100644
--- a/tensorflow/compiler/plugin/executor/executable.cc
+++ b/tensorflow/compiler/plugin/executor/executable.cc
@@ -28,7 +28,7 @@ namespace xla {
namespace executorplugin {
ExecutorExecutable::ExecutorExecutable(std::unique_ptr<HloModule> hlo_module)
- : Executable(std::move(hlo_module)) {}
+ : Executable(std::move(hlo_module), ShapeSizeBytes) {}
ExecutorExecutable::~ExecutorExecutable() {}
@@ -135,5 +135,13 @@ StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteAsyncOnStream(
"ExecuteAsyncOnStream is not yet supported on Executor.");
}
+/*static*/ int64 ExecutorExecutable::ShapeSizeBytes(const Shape& shape) {
+ if (ShapeUtil::IsOpaque(shape)) {
+ return sizeof(void*);
+ }
+ return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
+}
+
+
} // namespace executorplugin
} // namespace xla
diff --git a/tensorflow/compiler/plugin/executor/executable.h b/tensorflow/compiler/plugin/executor/executable.h
index 4278fa219f..1238185449 100644
--- a/tensorflow/compiler/plugin/executor/executable.h
+++ b/tensorflow/compiler/plugin/executor/executable.h
@@ -56,6 +56,8 @@ class ExecutorExecutable : public Executable {
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments) override;
+ static int64 ShapeSizeBytes(const Shape& shape);
+
private:
TF_DISALLOW_COPY_AND_ASSIGN(ExecutorExecutable);
};
diff --git a/tensorflow/compiler/plugin/executor/transfer_manager.cc b/tensorflow/compiler/plugin/executor/transfer_manager.cc
index e7219fd761..b59d20a779 100644
--- a/tensorflow/compiler/plugin/executor/transfer_manager.cc
+++ b/tensorflow/compiler/plugin/executor/transfer_manager.cc
@@ -170,8 +170,8 @@ int64 ExecutorTransferManager::GetByteSizeRequirement(const Shape& shape) {
} // namespace executorplugin
} // namespace xla
-static xla::TransferManager* CreateExecutorTransferManager() {
- return new xla::executorplugin::ExecutorTransferManager();
+static std::unique_ptr<xla::TransferManager> CreateExecutorTransferManager() {
+ return xla::MakeUnique<xla::executorplugin::ExecutorTransferManager>();
}
static bool InitModule() {