aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/sycl/sycl_device_context.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/sycl/sycl_device_context.cc')
-rw-r--r--tensorflow/core/common_runtime/sycl/sycl_device_context.cc256
1 files changed, 128 insertions, 128 deletions
diff --git a/tensorflow/core/common_runtime/sycl/sycl_device_context.cc b/tensorflow/core/common_runtime/sycl/sycl_device_context.cc
index a6be9195d4..1c868f5606 100644
--- a/tensorflow/core/common_runtime/sycl/sycl_device_context.cc
+++ b/tensorflow/core/common_runtime/sycl/sycl_device_context.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/common_runtime/sycl/sycl_device_context.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/sycl/sycl_device_context.h"
namespace tensorflow {
@@ -31,68 +31,68 @@ void SYCLDeviceContext::CopyCPUTensorToDevice(const Tensor *cpu_tensor,
const void *src_ptr = DMAHelper::base(cpu_tensor);
void *dst_ptr = DMAHelper::base(device_tensor);
switch (cpu_tensor->dtype()) {
- case DT_FLOAT:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<float *>(dst_ptr), static_cast<const float *>(src_ptr),
- total_bytes);
- break;
- case DT_DOUBLE:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<double *>(dst_ptr), static_cast<const double *>(src_ptr),
- total_bytes);
- break;
- case DT_INT32:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<int32 *>(dst_ptr), static_cast<const int32 *>(src_ptr),
- total_bytes);
- break;
- case DT_INT64:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<int64 *>(dst_ptr), static_cast<const int64 *>(src_ptr),
- total_bytes);
- break;
- case DT_HALF:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<Eigen::half *>(dst_ptr),
- static_cast<const Eigen::half *>(src_ptr), total_bytes);
- break;
- case DT_COMPLEX64:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<std::complex<float> *>(dst_ptr),
- static_cast<const std::complex<float> *>(src_ptr), total_bytes);
- break;
- case DT_COMPLEX128:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<std::complex<double> *>(dst_ptr),
- static_cast<const std::complex<double> *>(src_ptr), total_bytes);
- break;
- case DT_INT8:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<int8 *>(dst_ptr), static_cast<const int8 *>(src_ptr),
- total_bytes);
- break;
- case DT_INT16:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<int16 *>(dst_ptr), static_cast<const int16 *>(src_ptr),
- total_bytes);
- break;
- case DT_UINT8:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<uint8 *>(dst_ptr), static_cast<const uint8 *>(src_ptr),
- total_bytes);
- break;
- case DT_UINT16:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<uint16 *>(dst_ptr), static_cast<const uint16 *>(src_ptr),
- total_bytes);
- break;
- case DT_BOOL:
- device->eigen_sycl_device()->memcpyHostToDevice(
- static_cast<bool *>(dst_ptr), static_cast<const bool *>(src_ptr),
- total_bytes);
- break;
- default:
- assert(false && "unsupported type");
+ case DT_FLOAT:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<float *>(dst_ptr), static_cast<const float *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_DOUBLE:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<double *>(dst_ptr),
+ static_cast<const double *>(src_ptr), total_bytes);
+ break;
+ case DT_INT32:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<int32 *>(dst_ptr), static_cast<const int32 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_INT64:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<int64 *>(dst_ptr), static_cast<const int64 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_HALF:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<Eigen::half *>(dst_ptr),
+ static_cast<const Eigen::half *>(src_ptr), total_bytes);
+ break;
+ case DT_COMPLEX64:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<std::complex<float> *>(dst_ptr),
+ static_cast<const std::complex<float> *>(src_ptr), total_bytes);
+ break;
+ case DT_COMPLEX128:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<std::complex<double> *>(dst_ptr),
+ static_cast<const std::complex<double> *>(src_ptr), total_bytes);
+ break;
+ case DT_INT8:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<int8 *>(dst_ptr), static_cast<const int8 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_INT16:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<int16 *>(dst_ptr), static_cast<const int16 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_UINT8:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<uint8 *>(dst_ptr), static_cast<const uint8 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_UINT16:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<uint16 *>(dst_ptr),
+ static_cast<const uint16 *>(src_ptr), total_bytes);
+ break;
+ case DT_BOOL:
+ device->eigen_sycl_device()->memcpyHostToDevice(
+ static_cast<bool *>(dst_ptr), static_cast<const bool *>(src_ptr),
+ total_bytes);
+ break;
+ default:
+ assert(false && "unsupported type");
}
}
device->eigen_sycl_device()->synchronize();
@@ -106,71 +106,71 @@ void SYCLDeviceContext::CopyDeviceTensorToCPU(const Tensor *device_tensor,
StatusCallback done) {
const int64 total_bytes = device_tensor->TotalBytes();
if (total_bytes > 0) {
- const void* src_ptr = DMAHelper::base(device_tensor);
- void* dst_ptr = DMAHelper::base(cpu_tensor);
+ const void *src_ptr = DMAHelper::base(device_tensor);
+ void *dst_ptr = DMAHelper::base(cpu_tensor);
switch (device_tensor->dtype()) {
- case DT_FLOAT:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<float *>(dst_ptr), static_cast<const float *>(src_ptr),
- total_bytes);
- break;
- case DT_DOUBLE:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<double *>(dst_ptr), static_cast<const double *>(src_ptr),
- total_bytes);
- break;
- case DT_INT32:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<int32 *>(dst_ptr), static_cast<const int32 *>(src_ptr),
- total_bytes);
- break;
- case DT_INT64:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<int64 *>(dst_ptr), static_cast<const int64 *>(src_ptr),
- total_bytes);
- break;
- case DT_HALF:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<Eigen::half *>(dst_ptr),
- static_cast<const Eigen::half *>(src_ptr), total_bytes);
- break;
- case DT_COMPLEX64:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<std::complex<float> *>(dst_ptr),
- static_cast<const std::complex<float> *>(src_ptr), total_bytes);
- break;
- case DT_COMPLEX128:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<std::complex<double> *>(dst_ptr),
- static_cast<const std::complex<double> *>(src_ptr), total_bytes);
- break;
- case DT_INT8:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<int8 *>(dst_ptr), static_cast<const int8 *>(src_ptr),
- total_bytes);
- break;
- case DT_INT16:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<int16 *>(dst_ptr), static_cast<const int16 *>(src_ptr),
- total_bytes);
- break;
- case DT_UINT8:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<uint8 *>(dst_ptr), static_cast<const uint8 *>(src_ptr),
- total_bytes);
- break;
- case DT_UINT16:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<uint16 *>(dst_ptr), static_cast<const uint16 *>(src_ptr),
- total_bytes);
- break;
- case DT_BOOL:
- device->eigen_sycl_device()->memcpyDeviceToHost(
- static_cast<bool *>(dst_ptr), static_cast<const bool *>(src_ptr),
- total_bytes);
- break;
- default:
- assert(false && "unsupported type");
+ case DT_FLOAT:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<float *>(dst_ptr), static_cast<const float *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_DOUBLE:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<double *>(dst_ptr),
+ static_cast<const double *>(src_ptr), total_bytes);
+ break;
+ case DT_INT32:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<int32 *>(dst_ptr), static_cast<const int32 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_INT64:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<int64 *>(dst_ptr), static_cast<const int64 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_HALF:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<Eigen::half *>(dst_ptr),
+ static_cast<const Eigen::half *>(src_ptr), total_bytes);
+ break;
+ case DT_COMPLEX64:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<std::complex<float> *>(dst_ptr),
+ static_cast<const std::complex<float> *>(src_ptr), total_bytes);
+ break;
+ case DT_COMPLEX128:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<std::complex<double> *>(dst_ptr),
+ static_cast<const std::complex<double> *>(src_ptr), total_bytes);
+ break;
+ case DT_INT8:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<int8 *>(dst_ptr), static_cast<const int8 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_INT16:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<int16 *>(dst_ptr), static_cast<const int16 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_UINT8:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<uint8 *>(dst_ptr), static_cast<const uint8 *>(src_ptr),
+ total_bytes);
+ break;
+ case DT_UINT16:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<uint16 *>(dst_ptr),
+ static_cast<const uint16 *>(src_ptr), total_bytes);
+ break;
+ case DT_BOOL:
+ device->eigen_sycl_device()->memcpyDeviceToHost(
+ static_cast<bool *>(dst_ptr), static_cast<const bool *>(src_ptr),
+ total_bytes);
+ break;
+ default:
+ assert(false && "unsupported type");
}
}
device->eigen_sycl_device()->synchronize();
@@ -178,4 +178,4 @@ void SYCLDeviceContext::CopyDeviceTensorToCPU(const Tensor *device_tensor,
}
} // namespace tensorflow
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL