/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/tensor_testutil.h"

#include <cmath>
#include <cstdint>

namespace tensorflow {
namespace test {

// Element-wise closeness check for two tensors whose element type is T.
//
// Each element pair (x[i], y[i]) is compared with
// internal::Helper<T>::IsClose(x[i], y[i], atol, rtol); the exact comparison
// formula lives in that helper (declared in tensor_testutil.h). Every pair
// that is not close produces its own non-fatal gtest failure (EXPECT_TRUE)
// naming the index, both values and the tolerances that were used.
//
// Parameters:
//   x, y  - tensors to compare; both must hold elements of type T.
//   atol  - absolute tolerance; a negative value selects the default.
//   rtol  - relative tolerance; a negative value selects the default.
// The default tolerance is 5 * Eigen::NumTraits<T>::epsilon().
//
// Precondition: x and y have the same dtype and element count. Only
// x.NumElements() is consulted here, so the caller must verify this; the
// non-template overload below does so via internal::AssertSameTypeDims.
template <typename T>
void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) {
  const T* Tx = x.flat<T>().data();
  const T* Ty = y.flat<T>().data();
  // NumElements() is 64-bit; keep the count (and the loop index below) 64-bit
  // so tensors with more than INT_MAX elements do not overflow the index.
  const int64_t size = x.NumElements();

  // Tolerance's type (RealType) can be different from T.
  // For example, if T = std::complex<float>, then RealType = float.
  // Did not use std::numeric_limits<T> because
  // 1) It returns 0 for Eigen::half.
  // 2) It doesn't support T=std::complex<RealType>.
  //    (Would have to write a templated struct to handle this.)
  using RealType = decltype(Eigen::NumTraits<T>::epsilon());
  const RealType kSlackFactor = static_cast<RealType>(5.0);
  const RealType kDefaultTol = kSlackFactor * Eigen::NumTraits<T>::epsilon();

  // A negative tolerance means "use the type-dependent default".
  const RealType typed_atol =
      (atol < 0) ? kDefaultTol : static_cast<RealType>(atol);
  const RealType typed_rtol =
      (rtol < 0) ? kDefaultTol : static_cast<RealType>(rtol);
  // Fatal checks: they abort this function if a tolerance ends up negative
  // after the cast (e.g. NaN passed through unchanged fails the >= test).
  ASSERT_GE(typed_atol, static_cast<RealType>(0.0))
      << "typed_atol is negative: " << typed_atol;
  ASSERT_GE(typed_rtol, static_cast<RealType>(0.0))
      << "typed_rtol is negative: " << typed_rtol;

  for (int64_t i = 0; i < size; ++i) {
    EXPECT_TRUE(
        internal::Helper<T>::IsClose(Tx[i], Ty[i], typed_atol, typed_rtol))
        << "index = " << i << " x = " << Tx[i] << " y = " << Ty[i]
        << " typed_atol = " << typed_atol << " typed_rtol = " << typed_rtol;
  }
}

// Type-dispatching entry point: checks that x and y agree in dtype and shape,
// then runs the element-wise comparison for the matching element type.
//
// Supported dtypes: DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
// DT_COMPLEX128. Any other dtype terminates the process via LOG(FATAL).
// Negative atol/rtol select the per-type default tolerance (see above).
void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) {
  internal::AssertSameTypeDims(x, y);
  switch (x.dtype()) {
    case DT_HALF:
      ExpectClose<Eigen::half>(x, y, atol, rtol);
      break;
    case DT_FLOAT:
      ExpectClose<float>(x, y, atol, rtol);
      break;
    case DT_DOUBLE:
      ExpectClose<double>(x, y, atol, rtol);
      break;
    case DT_COMPLEX64:
      ExpectClose<complex64>(x, y, atol, rtol);
      break;
    case DT_COMPLEX128:
      ExpectClose<complex128>(x, y, atol, rtol);
      break;
    default:
      LOG(FATAL) << "Unexpected type : " << DataTypeString(x.dtype());
  }
}

}  // end namespace test
}  // end namespace tensorflow