diff options
author | 2017-05-26 07:59:40 +0100 | |
---|---|---|
committer | 2017-06-13 18:41:07 -0700 | |
commit | d59b64fa4c5220eb50e32703b39832f3ed0b3726 (patch) | |
tree | d2704503b65c2cab09bf5f16e7011b53318dd614 /tensorflow | |
parent | 43d8475a551723d905cf7bd122825fa58556f06e (diff) |
Adjust API to match current public repo
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/compiler/plugin/executor/compiler.cc | 10 | ||||
-rw-r--r-- | tensorflow/compiler/plugin/executor/compiler.h | 2 | ||||
-rw-r--r-- | tensorflow/compiler/plugin/executor/device.cc | 4 | ||||
-rw-r--r-- | tensorflow/compiler/plugin/executor/executable.cc | 10 | ||||
-rw-r--r-- | tensorflow/compiler/plugin/executor/executable.h | 2 | ||||
-rw-r--r-- | tensorflow/compiler/plugin/executor/transfer_manager.cc | 4 |
6 files changed, 22 insertions, 10 deletions
diff --git a/tensorflow/compiler/plugin/executor/compiler.cc b/tensorflow/compiler/plugin/executor/compiler.cc index a3591a8a47..893ff152f0 100644 --- a/tensorflow/compiler/plugin/executor/compiler.cc +++ b/tensorflow/compiler/plugin/executor/compiler.cc @@ -103,14 +103,16 @@ ExecutorCompiler::CompileAheadOfTime( "AOT compilation not supported on Executor"); } -int64 ExecutorCompiler::ShapeSizeBytes(const Shape& shape) const { - return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); -} - se::Platform::Id ExecutorCompiler::PlatformId() const { return sep::kExecutorPlatformId; } +HloCostAnalysis::ShapeSizeFunction +ExecutorCompiler::ShapeSizeBytesFunction() const { + return ExecutorExecutable::ShapeSizeBytes; +} + + } // namespace executorplugin } // namespace xla diff --git a/tensorflow/compiler/plugin/executor/compiler.h b/tensorflow/compiler/plugin/executor/compiler.h index e66f0a4ea3..8fe591c8ab 100644 --- a/tensorflow/compiler/plugin/executor/compiler.h +++ b/tensorflow/compiler/plugin/executor/compiler.h @@ -48,7 +48,7 @@ class ExecutorCompiler : public Compiler { std::vector<std::unique_ptr<HloModule>> module, HloDumper dump_hlo, const AotCompilationOptions& options) override; - int64 ShapeSizeBytes(const Shape& shape) const override; + HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override; perftools::gputools::Platform::Id PlatformId() const override; diff --git a/tensorflow/compiler/plugin/executor/device.cc b/tensorflow/compiler/plugin/executor/device.cc index 10e13226f9..54b49d789d 100644 --- a/tensorflow/compiler/plugin/executor/device.cc +++ b/tensorflow/compiler/plugin/executor/device.cc @@ -27,7 +27,7 @@ const char* const DEVICE_XLA_EXEC = "XLA_EXEC"; const char* const DEVICE_EXEC_XLA_JIT = "XLA_EXEC_JIT"; constexpr std::array<DataType, 5> kExecAllTypes = { - {DT_INT32, DT_FLOAT, DT_BOOL}}; + {DT_INT32, DT_FLOAT, DT_BOOL, DT_DOUBLE, DT_INT64}}; class XlaExaDeviceFactory : public DeviceFactory { public: @@ -50,7 +50,7 @@ Status XlaExaDeviceFactory::CreateDevices(const SessionOptions& options, return Status::OK(); } -REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_EXEC, XlaExaDeviceFactory, 210); +REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_EXEC, XlaExaDeviceFactory, 110); // Kernel registrations diff --git a/tensorflow/compiler/plugin/executor/executable.cc b/tensorflow/compiler/plugin/executor/executable.cc index 0caee906ac..480d9d98bb 100644 --- a/tensorflow/compiler/plugin/executor/executable.cc +++ b/tensorflow/compiler/plugin/executor/executable.cc @@ -28,7 +28,7 @@ namespace xla { namespace executorplugin { ExecutorExecutable::ExecutorExecutable(std::unique_ptr<HloModule> hlo_module) - : Executable(std::move(hlo_module)) {} + : Executable(std::move(hlo_module), ShapeSizeBytes) {} ExecutorExecutable::~ExecutorExecutable() {} @@ -135,5 +135,13 @@ StatusOr<se::DeviceMemoryBase> ExecutorExecutable::ExecuteAsyncOnStream( "ExecuteAsyncOnStream is not yet supported on Executor."); } +/*static*/ int64 ExecutorExecutable::ShapeSizeBytes(const Shape& shape) { + if (ShapeUtil::IsOpaque(shape)) { + return sizeof(void*); + } + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); +} + + } // namespace executorplugin } // namespace xla diff --git a/tensorflow/compiler/plugin/executor/executable.h b/tensorflow/compiler/plugin/executor/executable.h index 4278fa219f..1238185449 100644 --- a/tensorflow/compiler/plugin/executor/executable.h +++ b/tensorflow/compiler/plugin/executor/executable.h @@ -56,6 +56,8 @@ class ExecutorExecutable : public Executable { tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase> arguments) override; + static int64 ShapeSizeBytes(const Shape& shape); + private: TF_DISALLOW_COPY_AND_ASSIGN(ExecutorExecutable); }; diff --git a/tensorflow/compiler/plugin/executor/transfer_manager.cc b/tensorflow/compiler/plugin/executor/transfer_manager.cc index e7219fd761..b59d20a779 100644 --- a/tensorflow/compiler/plugin/executor/transfer_manager.cc +++ b/tensorflow/compiler/plugin/executor/transfer_manager.cc @@ -170,8 +170,8 @@ int64 ExecutorTransferManager::GetByteSizeRequirement(const Shape& shape) { } // namespace executorplugin } // namespace xla -static xla::TransferManager* CreateExecutorTransferManager() { - return new xla::executorplugin::ExecutorTransferManager(); +static std::unique_ptr<xla::TransferManager> CreateExecutorTransferManager() { + return xla::MakeUnique<xla::executorplugin::ExecutorTransferManager>(); } static bool InitModule() { |