aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/randomized_tests.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tests/randomized_tests.cc')
-rw-r--r--tensorflow/compiler/tests/randomized_tests.cc50
1 files changed, 26 insertions, 24 deletions
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index 0faf0fd8ed..bddda6f302 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -45,6 +45,8 @@ limitations under the License.
#include <random>
#include <unordered_map>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/core/common_runtime/device.h"
@@ -61,7 +63,6 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
@@ -81,7 +82,7 @@ string* tf_xla_test_device_ptr; // initial value set in main()
bool tf_xla_test_use_jit = true;
string LocalDeviceToFullDeviceName(const string& device) {
- return strings::StrCat("/job:localhost/replica:0/task:0/device:", device);
+ return absl::StrCat("/job:localhost/replica:0/task:0/device:", device);
}
constexpr std::array<DataType, 5> kAllXlaTypes = {
@@ -107,11 +108,12 @@ class OpTestBuilder {
// Sets an attribute.
template <class T>
- OpTestBuilder& Attr(StringPiece attr_name, T&& value);
+ OpTestBuilder& Attr(absl::string_view attr_name, T&& value);
// Overload needed to allow {...} expressions for value.
template <class T>
- OpTestBuilder& Attr(StringPiece attr_name, std::initializer_list<T> value);
+ OpTestBuilder& Attr(absl::string_view attr_name,
+ std::initializer_list<T> value);
// Adds nodes that executes the operator under test on 'device' to 'graphdef'.
// If 'use_jit' is true, marks the operator under test to be compiled by XLA.
@@ -185,13 +187,13 @@ OpTestBuilder& OpTestBuilder::RandomUniqueInput(DataType type,
}
template <class T>
-OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, T&& value) {
+OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, T&& value) {
AddNodeAttr(attr_name, std::forward<T>(value), &node_def_);
return *this;
}
template <class T>
-OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name,
+OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name,
std::initializer_list<T> value) {
Attr<std::initializer_list<T>>(attr_name, std::move(value));
return *this;
@@ -209,7 +211,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix,
NodeDef* test_def = graphdef->add_node();
*test_def = node_def_;
- test_def->set_name(strings::StrCat(name_prefix, "_op_under_test"));
+ test_def->set_name(absl::StrCat(name_prefix, "_op_under_test"));
test_def->set_device(device);
AddDefaultsToNodeDef(*op_def, test_def);
if (use_jit) {
@@ -224,7 +226,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix,
// Build feed and fetch nodes.
for (int i = 0; i < input_types.size(); ++i) {
NodeDef* def = graphdef->add_node();
- string name = strings::StrCat(name_prefix, "_input_", i);
+ string name = absl::StrCat(name_prefix, "_input_", i);
TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Placeholder")
.Device(device)
.Attr("dtype", input_types[i])
@@ -235,7 +237,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix,
for (int i = 0; i < output_types.size(); ++i) {
NodeDef* def = graphdef->add_node();
- string name = strings::StrCat(name_prefix, "_output_", i);
+ string name = absl::StrCat(name_prefix, "_output_", i);
TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Identity")
.Device(device)
.Attr("T", output_types[i])
@@ -726,11 +728,11 @@ bool IsClose<complex64>(const complex64& x, const complex64& y, double atol,
template <typename T>
string Str(T x) {
- return strings::StrCat(x);
+ return absl::StrCat(x);
}
template <>
string Str<complex64>(complex64 x) {
- return strings::StrCat("(", x.real(), ", ", x.imag(), ")");
+ return absl::StrCat("(", x.real(), ", ", x.imag(), ")");
}
template <typename T>
@@ -740,11 +742,11 @@ Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol,
auto Ty = y.flat<T>();
for (int i = 0; i < Tx.size(); ++i) {
if (!IsClose(Tx(i), Ty(i), atol, rtol)) {
- return errors::InvalidArgument(strings::StrCat(
- i, "-th tensor element isn't close: ", Str(Tx(i)), " vs. ",
- Str(Ty(i)), ". x = ", x.DebugString(), "y = ", y.DebugString(),
- "atol = ", atol, " rtol = ", rtol,
- " tol = ", atol + rtol * Abs(Tx(i))));
+ return errors::InvalidArgument(
+ absl::StrCat(i, "-th tensor element isn't close: ", Str(Tx(i)),
+ " vs. ", Str(Ty(i)), ". x = ", x.DebugString(),
+ "y = ", y.DebugString(), "atol = ", atol,
+ " rtol = ", rtol, " tol = ", atol + rtol * Abs(Tx(i))));
}
}
return Status::OK();
@@ -756,7 +758,7 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) {
auto Ty = y.flat<T>();
for (int i = 0; i < Tx.size(); ++i) {
if (Tx(i) != Ty(i)) {
- return errors::InvalidArgument(strings::StrCat(
+ return errors::InvalidArgument(absl::StrCat(
i, "-th tensor element isn't equal: ", Tx(i), " vs. ", Ty(i),
". x = ", x.DebugString(), "y = ", y.DebugString()));
}
@@ -771,14 +773,14 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) {
Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol,
double rtol) {
if (a.dtype() != b.dtype()) {
- return errors::InvalidArgument(strings::StrCat(
+ return errors::InvalidArgument(absl::StrCat(
"Tensors have different types: ", DataTypeString(a.dtype()), " and ",
DataTypeString(b.dtype())));
}
if (!a.IsSameSize(b)) {
- return errors::InvalidArgument(strings::StrCat(
- "Tensors have different shapes: ", a.shape().DebugString(), " and ",
- b.shape().DebugString()));
+ return errors::InvalidArgument(
+ absl::StrCat("Tensors have different shapes: ", a.shape().DebugString(),
+ " and ", b.shape().DebugString()));
}
switch (a.dtype()) {
@@ -827,7 +829,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose(
}
string cpu_device =
- LocalDeviceToFullDeviceName(strings::StrCat(DEVICE_CPU, ":0"));
+ LocalDeviceToFullDeviceName(absl::StrCat(DEVICE_CPU, ":0"));
string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr);
DeviceNameUtils::ParsedName parsed_name;
@@ -842,7 +844,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose(
std::vector<string> expected_inputs, test_inputs;
std::vector<string> expected_fetches, test_fetches;
Status status = builder.BuildGraph(
- strings::StrCat("test", num_tests_, "_expected"), cpu_device,
+ absl::StrCat("test", num_tests_, "_expected"), cpu_device,
/* use_jit= */ false, &graph, /* test_node_def= */ nullptr,
&expected_inputs, &expected_fetches);
if (!status.ok()) {
@@ -851,7 +853,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose(
}
NodeDef* node_def;
- status = builder.BuildGraph(strings::StrCat("test", num_tests_, "_test"),
+ status = builder.BuildGraph(absl::StrCat("test", num_tests_, "_test"),
test_device, tf_xla_test_use_jit, &graph,
&node_def, &test_inputs, &test_fetches);
if (!status.ok()) {