aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2017-06-28 17:55:23 +0000
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2017-06-28 17:55:23 +0000
commit53725c10b80dabd2a536f66e854c50f892496946 (patch)
treebc49b7e7f3372c92a6c8ea07068ed54ef38ab29d /unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h
parentb8e805497e446e7159f231238b4a8fd22fe70749 (diff)
Merged in mehdi_goli/opencl/DataDependancy (pull request PR-10)
DataDependancy * Wrapping data type to the pointer class for sycl in non-terminal nodes; not having that breaks Tensorflow Conv2d code. * Applying Ronnan's Comments. * Applying benoit's comments
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h4
1 files changed, 3 insertions, 1 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h
index 2c7ba961c..363df876c 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h
@@ -37,6 +37,8 @@ struct traits<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> >
static const int NumDimensions = traits<LhsXprType>::NumDimensions;
static const int Layout = traits<LhsXprType>::Layout;
enum { Flags = 0 };
+ typedef typename conditional<::Eigen::internal::Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
+ typename traits<LhsXprType>::PointerType, typename traits<RhsXprType>::PointerType>::type PointerType;
};
template<typename Axis, typename LhsXprType, typename RhsXprType>
@@ -275,7 +277,7 @@ struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgTy
TensorOpCost(0, 0, compute_cost);
}
- EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
+ EIGEN_DEVICE_FUNC typename Eigen::internal::traits<XprType>::PointerType data() const { return NULL; }
/// required by sycl in order to extract the accessor
const TensorEvaluator<LeftArgType, Device>& left_impl() const { return m_leftImpl; }
/// required by sycl in order to extract the accessor