diff options
author | 2018-09-05 13:50:20 -0700 | |
---|---|---|
committer | 2018-09-05 13:54:16 -0700 | |
commit | 11caab3c138d06390344c88a4149f1897e3d780d (patch) | |
tree | 33aac05bfa4fdf6cd81998232268c79000f0d4d0 /tensorflow/compiler/tests | |
parent | c9c8de440213355ea4a4d3577fd068d418678d38 (diff) |
[XLA] Make tensorflow/compiler use absl::{StrCat,string_view,InlinedVector} consistently
StringPiece is an alias for absl::string_view, InlinedVector is aliased to absl::InlinedVector. StrCat is compatible, so swapping it out is safe.
PiperOrigin-RevId: 211691840
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r-- | tensorflow/compiler/tests/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/tests/randomized_tests.cc | 50 |
2 files changed, 27 insertions, 24 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 34defe1c7a..050d827a09 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1103,6 +1103,7 @@ cc_library( "//tensorflow/core:test", "//tensorflow/core:testlib", "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 0faf0fd8ed..bddda6f302 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -45,6 +45,8 @@ limitations under the License. #include <random> #include <unordered_map> +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/core/common_runtime/device.h" @@ -61,7 +63,6 @@ limitations under the License. #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session.h" @@ -81,7 +82,7 @@ string* tf_xla_test_device_ptr; // initial value set in main() bool tf_xla_test_use_jit = true; string LocalDeviceToFullDeviceName(const string& device) { - return strings::StrCat("/job:localhost/replica:0/task:0/device:", device); + return absl::StrCat("/job:localhost/replica:0/task:0/device:", device); } constexpr std::array<DataType, 5> kAllXlaTypes = { @@ -107,11 +108,12 @@ class OpTestBuilder { // Sets an attribute. template <class T> - OpTestBuilder& Attr(StringPiece attr_name, T&& value); + OpTestBuilder& Attr(absl::string_view attr_name, T&& value); // Overload needed to allow {...} expressions for value. template <class T> - OpTestBuilder& Attr(StringPiece attr_name, std::initializer_list<T> value); + OpTestBuilder& Attr(absl::string_view attr_name, + std::initializer_list<T> value); // Adds nodes that executes the operator under test on 'device' to 'graphdef'. // If 'use_jit' is true, marks the operator under test to be compiled by XLA. @@ -185,13 +187,13 @@ OpTestBuilder& OpTestBuilder::RandomUniqueInput(DataType type, } template <class T> -OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, T&& value) { +OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, T&& value) { AddNodeAttr(attr_name, std::forward<T>(value), &node_def_); return *this; } template <class T> -OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, +OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, std::initializer_list<T> value) { Attr<std::initializer_list<T>>(attr_name, std::move(value)); return *this; @@ -209,7 +211,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix, NodeDef* test_def = graphdef->add_node(); *test_def = node_def_; - test_def->set_name(strings::StrCat(name_prefix, "_op_under_test")); + test_def->set_name(absl::StrCat(name_prefix, "_op_under_test")); test_def->set_device(device); AddDefaultsToNodeDef(*op_def, test_def); if (use_jit) { @@ -224,7 +226,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix, // Build feed and fetch nodes. for (int i = 0; i < input_types.size(); ++i) { NodeDef* def = graphdef->add_node(); - string name = strings::StrCat(name_prefix, "_input_", i); + string name = absl::StrCat(name_prefix, "_input_", i); TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Placeholder") .Device(device) .Attr("dtype", input_types[i]) @@ -235,7 +237,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix, for (int i = 0; i < output_types.size(); ++i) { NodeDef* def = graphdef->add_node(); - string name = strings::StrCat(name_prefix, "_output_", i); + string name = absl::StrCat(name_prefix, "_output_", i); TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Identity") .Device(device) .Attr("T", output_types[i]) @@ -726,11 +728,11 @@ bool IsClose<complex64>(const complex64& x, const complex64& y, double atol, template <typename T> string Str(T x) { - return strings::StrCat(x); + return absl::StrCat(x); } template <> string Str<complex64>(complex64 x) { - return strings::StrCat("(", x.real(), ", ", x.imag(), ")"); + return absl::StrCat("(", x.real(), ", ", x.imag(), ")"); } template <typename T> @@ -740,11 +742,11 @@ Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol, auto Ty = y.flat<T>(); for (int i = 0; i < Tx.size(); ++i) { if (!IsClose(Tx(i), Ty(i), atol, rtol)) { - return errors::InvalidArgument(strings::StrCat( - i, "-th tensor element isn't close: ", Str(Tx(i)), " vs. ", - Str(Ty(i)), ". x = ", x.DebugString(), "y = ", y.DebugString(), - "atol = ", atol, " rtol = ", rtol, - " tol = ", atol + rtol * Abs(Tx(i)))); + return errors::InvalidArgument( + absl::StrCat(i, "-th tensor element isn't close: ", Str(Tx(i)), + " vs. ", Str(Ty(i)), ". x = ", x.DebugString(), + "y = ", y.DebugString(), "atol = ", atol, + " rtol = ", rtol, " tol = ", atol + rtol * Abs(Tx(i)))); } } return Status::OK(); @@ -756,7 +758,7 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) { auto Ty = y.flat<T>(); for (int i = 0; i < Tx.size(); ++i) { if (Tx(i) != Ty(i)) { - return errors::InvalidArgument(strings::StrCat( + return errors::InvalidArgument(absl::StrCat( i, "-th tensor element isn't equal: ", Tx(i), " vs. ", Ty(i), ". x = ", x.DebugString(), "y = ", y.DebugString())); } @@ -771,14 +773,14 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) { Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol, double rtol) { if (a.dtype() != b.dtype()) { - return errors::InvalidArgument(strings::StrCat( + return errors::InvalidArgument(absl::StrCat( "Tensors have different types: ", DataTypeString(a.dtype()), " and ", DataTypeString(b.dtype()))); } if (!a.IsSameSize(b)) { - return errors::InvalidArgument(strings::StrCat( - "Tensors have different shapes: ", a.shape().DebugString(), " and ", - b.shape().DebugString())); + return errors::InvalidArgument( + absl::StrCat("Tensors have different shapes: ", a.shape().DebugString(), + " and ", b.shape().DebugString())); } switch (a.dtype()) { @@ -827,7 +829,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( } string cpu_device = - LocalDeviceToFullDeviceName(strings::StrCat(DEVICE_CPU, ":0")); + LocalDeviceToFullDeviceName(absl::StrCat(DEVICE_CPU, ":0")); string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr); DeviceNameUtils::ParsedName parsed_name; @@ -842,7 +844,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( std::vector<string> expected_inputs, test_inputs; std::vector<string> expected_fetches, test_fetches; Status status = builder.BuildGraph( - strings::StrCat("test", num_tests_, "_expected"), cpu_device, + absl::StrCat("test", num_tests_, "_expected"), cpu_device, /* use_jit= */ false, &graph, /* test_node_def= */ nullptr, &expected_inputs, &expected_fetches); if (!status.ok()) { @@ -851,7 +853,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( } NodeDef* node_def; - status = builder.BuildGraph(strings::StrCat("test", num_tests_, "_test"), + status = builder.BuildGraph(absl::StrCat("test", num_tests_, "_test"), test_device, tf_xla_test_use_jit, &graph, &node_def, &test_inputs, &test_fetches); if (!status.ok()) { |