diff options
Diffstat (limited to 'tensorflow/core/common_runtime/sycl/sycl_device_context.cc')
-rw-r--r-- | tensorflow/core/common_runtime/sycl/sycl_device_context.cc | 256 |
1 files changed, 128 insertions, 128 deletions
diff --git a/tensorflow/core/common_runtime/sycl/sycl_device_context.cc b/tensorflow/core/common_runtime/sycl/sycl_device_context.cc index a6be9195d4..1c868f5606 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device_context.cc +++ b/tensorflow/core/common_runtime/sycl/sycl_device_context.cc @@ -17,8 +17,8 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/common_runtime/sycl/sycl_device_context.h" #include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/sycl/sycl_device_context.h" namespace tensorflow { @@ -31,68 +31,68 @@ void SYCLDeviceContext::CopyCPUTensorToDevice(const Tensor *cpu_tensor, const void *src_ptr = DMAHelper::base(cpu_tensor); void *dst_ptr = DMAHelper::base(device_tensor); switch (cpu_tensor->dtype()) { - case DT_FLOAT: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast<float *>(dst_ptr), static_cast<const float *>(src_ptr), - total_bytes); - break; - case DT_DOUBLE: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast<double *>(dst_ptr), static_cast<const double *>(src_ptr), - total_bytes); - break; - case DT_INT32: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast<int32 *>(dst_ptr), static_cast<const int32 *>(src_ptr), - total_bytes); - break; - case DT_INT64: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast<int64 *>(dst_ptr), static_cast<const int64 *>(src_ptr), - total_bytes); - break; - case DT_HALF: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast<Eigen::half *>(dst_ptr), - static_cast<const Eigen::half *>(src_ptr), total_bytes); - break; - case DT_COMPLEX64: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast<std::complex<float> *>(dst_ptr), - static_cast<const std::complex<float> *>(src_ptr), total_bytes); - break; - case DT_COMPLEX128: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast<std::complex<double> *>(dst_ptr), - static_cast<const std::complex<double> *>(src_ptr), total_bytes); - break; - case DT_INT8: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast<int8 *>(dst_ptr), static_cast<const int8 *>(src_ptr), - total_bytes); - break; - case DT_INT16: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast<int16 *>(dst_ptr), static_cast<const int16 *>(src_ptr), - total_bytes); - break; - case DT_UINT8: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast<uint8 *>(dst_ptr), static_cast<const uint8 *>(src_ptr), - total_bytes); - break; - case DT_UINT16: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast<uint16 *>(dst_ptr), static_cast<const uint16 *>(src_ptr), - total_bytes); - break; - case DT_BOOL: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast<bool *>(dst_ptr), static_cast<const bool *>(src_ptr), - total_bytes); - break; - default: - assert(false && "unsupported type"); + case DT_FLOAT: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast<float *>(dst_ptr), static_cast<const float *>(src_ptr), + total_bytes); + break; + case DT_DOUBLE: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast<double *>(dst_ptr), + static_cast<const double *>(src_ptr), total_bytes); + break; + case DT_INT32: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast<int32 *>(dst_ptr), static_cast<const int32 *>(src_ptr), + total_bytes); + break; + case DT_INT64: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast<int64 *>(dst_ptr), static_cast<const int64 *>(src_ptr), + total_bytes); + break; + case DT_HALF: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast<Eigen::half *>(dst_ptr), + static_cast<const Eigen::half *>(src_ptr), total_bytes); + break; + case DT_COMPLEX64: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast<std::complex<float> *>(dst_ptr), + static_cast<const std::complex<float> *>(src_ptr), total_bytes); + break; + case DT_COMPLEX128: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast<std::complex<double> *>(dst_ptr), + static_cast<const std::complex<double> *>(src_ptr), total_bytes); + break; + case DT_INT8: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast<int8 *>(dst_ptr), static_cast<const int8 *>(src_ptr), + total_bytes); + break; + case DT_INT16: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast<int16 *>(dst_ptr), static_cast<const int16 *>(src_ptr), + total_bytes); + break; + case DT_UINT8: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast<uint8 *>(dst_ptr), static_cast<const uint8 *>(src_ptr), + total_bytes); + break; + case DT_UINT16: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast<uint16 *>(dst_ptr), + static_cast<const uint16 *>(src_ptr), total_bytes); + break; + case DT_BOOL: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast<bool *>(dst_ptr), static_cast<const bool *>(src_ptr), + total_bytes); + break; + default: + assert(false && "unsupported type"); } } device->eigen_sycl_device()->synchronize(); @@ -106,71 +106,71 @@ void SYCLDeviceContext::CopyDeviceTensorToCPU(const Tensor *device_tensor, StatusCallback done) { const int64 total_bytes = device_tensor->TotalBytes(); if (total_bytes > 0) { - const void* src_ptr = DMAHelper::base(device_tensor); - void* dst_ptr = DMAHelper::base(cpu_tensor); + const void *src_ptr = DMAHelper::base(device_tensor); + void *dst_ptr = DMAHelper::base(cpu_tensor); switch (device_tensor->dtype()) { - case DT_FLOAT: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast<float *>(dst_ptr), static_cast<const float *>(src_ptr), - total_bytes); - break; - case DT_DOUBLE: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast<double *>(dst_ptr), static_cast<const double *>(src_ptr), - total_bytes); - break; - case DT_INT32: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast<int32 *>(dst_ptr), static_cast<const int32 *>(src_ptr), - total_bytes); - break; - case DT_INT64: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast<int64 *>(dst_ptr), static_cast<const int64 *>(src_ptr), - total_bytes); - break; - case DT_HALF: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast<Eigen::half *>(dst_ptr), - static_cast<const Eigen::half *>(src_ptr), total_bytes); - break; - case DT_COMPLEX64: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast<std::complex<float> *>(dst_ptr), - static_cast<const std::complex<float> *>(src_ptr), total_bytes); - break; - case DT_COMPLEX128: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast<std::complex<double> *>(dst_ptr), - static_cast<const std::complex<double> *>(src_ptr), total_bytes); - break; - case DT_INT8: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast<int8 *>(dst_ptr), static_cast<const int8 *>(src_ptr), - total_bytes); - break; - case DT_INT16: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast<int16 *>(dst_ptr), static_cast<const int16 *>(src_ptr), - total_bytes); - break; - case DT_UINT8: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast<uint8 *>(dst_ptr), static_cast<const uint8 *>(src_ptr), - total_bytes); - break; - case DT_UINT16: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast<uint16 *>(dst_ptr), static_cast<const uint16 *>(src_ptr), - total_bytes); - break; - case DT_BOOL: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast<bool *>(dst_ptr), static_cast<const bool *>(src_ptr), - total_bytes); - break; - default: - assert(false && "unsupported type"); + case DT_FLOAT: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast<float *>(dst_ptr), static_cast<const float *>(src_ptr), + total_bytes); + break; + case DT_DOUBLE: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast<double *>(dst_ptr), + static_cast<const double *>(src_ptr), total_bytes); + break; + case DT_INT32: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast<int32 *>(dst_ptr), static_cast<const int32 *>(src_ptr), + total_bytes); + break; + case DT_INT64: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast<int64 *>(dst_ptr), static_cast<const int64 *>(src_ptr), + total_bytes); + break; + case DT_HALF: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast<Eigen::half *>(dst_ptr), + static_cast<const Eigen::half *>(src_ptr), total_bytes); + break; + case DT_COMPLEX64: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast<std::complex<float> *>(dst_ptr), + static_cast<const std::complex<float> *>(src_ptr), total_bytes); + break; + case DT_COMPLEX128: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast<std::complex<double> *>(dst_ptr), + static_cast<const std::complex<double> *>(src_ptr), total_bytes); + break; + case DT_INT8: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast<int8 *>(dst_ptr), static_cast<const int8 *>(src_ptr), + total_bytes); + break; + case DT_INT16: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast<int16 *>(dst_ptr), static_cast<const int16 *>(src_ptr), + total_bytes); + break; + case DT_UINT8: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast<uint8 *>(dst_ptr), static_cast<const uint8 *>(src_ptr), + total_bytes); + break; + case DT_UINT16: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast<uint16 *>(dst_ptr), + static_cast<const uint16 *>(src_ptr), total_bytes); + break; + case DT_BOOL: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast<bool *>(dst_ptr), static_cast<const bool *>(src_ptr), + total_bytes); + break; + default: + assert(false && "unsupported type"); } } device->eigen_sycl_device()->synchronize(); @@ -178,4 +178,4 @@ void SYCLDeviceContext::CopyDeviceTensorToCPU(const Tensor *device_tensor, } } // namespace tensorflow -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL |