aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-30 21:08:05 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-30 21:11:33 -0800
commit6eec9c2ea33f3b86012cb0ea2aeb9e49e65bc716 (patch)
treed02d0e5149055a01318278876967ef3d04796de6
parent1ec61fafe13e5edce6e45d5a67e960efb9df618a (diff)
[XLA] Hlo parser: support rng and reduce-precision. Also simplify the lexer by regarding several things as identifier.
PiperOrigin-RevId: 177548483
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc30
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h5
-rw-r--r--tensorflow/compiler/xla/tools/parser/BUILD2
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_lexer.cc32
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_lexer.h14
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser.cc81
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc25
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_token.h6
8 files changed, 149 insertions, 46 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index b4bac18bcd..45825c7c76 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -2060,6 +2060,14 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
extra.push_back(
StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\""));
}
+ if (opcode() == HloOpcode::kRng) {
+ extra.push_back(
+ StrCat("distribution=", RandomDistributionToString(distribution_)));
+ }
+ if (opcode() == HloOpcode::kReducePrecision) {
+ extra.push_back(StrCat("exponent_bits=", exponent_bits_));
+ extra.push_back(StrCat("mantissa_bits=", mantissa_bits_));
+ }
return extra;
}
@@ -3029,6 +3037,28 @@ string OpMetadataToString(const OpMetadata& metadata) {
return Join(result, " ");
}
+string RandomDistributionToString(const RandomDistribution& distribution) {
+ return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution));
+}
+
+StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
+ static std::unordered_map<string, RandomDistribution>* map = [] {
+ static auto* map = new std::unordered_map<string, RandomDistribution>;
+ for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
+ if (RandomDistribution_IsValid(i)) {
+ auto value = static_cast<RandomDistribution>(i);
+ (*map)[RandomDistributionToString(value)] = value;
+ }
+ }
+ return map;
+ }();
+ auto found = map->find(tensorflow::str_util::Lowercase(name));
+ if (found == map->end()) {
+ return InvalidArgument("Unknown distribution");
+ }
+ return found->second;
+}
+
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
return os << ToString(kind);
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 768c027a42..088902e2a7 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -1285,9 +1285,12 @@ string ToString(HloInstruction::FusionKind kind);
StatusOr<HloInstruction::FusionKind> StringToFusionKind(
const string& kind_name);
-// Custom stringification functions for protos that live inside HloInstruction.
+// Custom (de)stringification functions for protos that live inside
+// HloInstruction.
string PaddingConfigToString(const PaddingConfig& padding);
string OpMetadataToString(const OpMetadata& metadata);
+string RandomDistributionToString(const RandomDistribution& distribution);
+StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD
index ce936af6c3..97aacf6b39 100644
--- a/tensorflow/compiler/xla/tools/parser/BUILD
+++ b/tensorflow/compiler/xla/tools/parser/BUILD
@@ -34,9 +34,9 @@ cc_library(
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
],
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
index 56744440db..04247594ed 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
@@ -17,7 +17,6 @@ limitations under the License.
#include <unordered_map>
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
@@ -153,15 +152,15 @@ TokKind HloLexer::LexToken() {
}
}
-// Lex a shape, name, keyword, opcode, attribute name, or the dim labels
-// pattern.
+// Lex a shape, name, keyword, attribute name, the dim labels pattern, and
+// other identifiers.
//
// shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})?
// name ::= [a-zA-Z_][a-zA-Z0-9_.-]*:
// keyword ::= HloModule, ENTRY, ...
-// opcode ::= add, greater-than, ...
// attribute_name ::= condition, body, dimensions, ...
// dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,}
+// identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]*
TokKind HloLexer::LexIdentifier() {
{
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
@@ -220,20 +219,6 @@ TokKind HloLexer::LexIdentifier() {
#undef KEYWORD
- // See if this is an opcode.
- auto opcode = StringToHloOpcode(identifier.ToString());
- if (opcode.ok()) {
- opcode_val_ = opcode.ValueOrDie();
- return TokKind::kOpcode;
- }
-
- // See if this is an fusion kind.
- auto kind = xla::StringToFusionKind(identifier.ToString());
- if (kind.ok()) {
- fusion_kind_val_ = kind.ValueOrDie();
- return TokKind::kFusionKind;
- }
-
{
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
static LazyRE2 dim_labels_pattern = {
@@ -244,8 +229,9 @@ TokKind HloLexer::LexIdentifier() {
return TokKind::kDimLabels;
}
}
- current_ptr_ = token_start_ + 1;
- return TokKind::kError;
+
+ str_val_ = identifier.ToString();
+ return TokKind::kIdent;
}
// Lex names after a % character.
@@ -428,14 +414,12 @@ string TokKindToString(TokKind kind) {
return "kDxD";
case TokKind::kPad:
return "kPad";
+ case TokKind::kIdent:
+ return "kIdent";
case TokKind::kString:
return "kString";
case TokKind::kShape:
return "kShape";
- case TokKind::kOpcode:
- return "kOpcode";
- case TokKind::kFusionKind:
- return "kFusionKind";
case TokKind::kInt:
return "kInt";
case TokKind::kDecimal:
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h
index 5c9d1bf391..9daf6a11d3 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h
+++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h
@@ -18,9 +18,8 @@ limitations under the License.
#include <string>
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_token.h"
+#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
@@ -48,6 +47,7 @@ class HloLexer {
case TokKind::kDxD:
case TokKind::kPad:
case TokKind::kString:
+ case TokKind::kIdent:
return str_val_;
default:
LOG(FATAL) << "This token does not have string value";
@@ -57,14 +57,6 @@ class HloLexer {
CHECK(GetKind() == TokKind::kShape);
return shape_val_;
}
- HloOpcode GetOpcodeVal() const {
- CHECK(GetKind() == TokKind::kOpcode);
- return opcode_val_;
- }
- HloInstruction::FusionKind GetFusionKindVal() const {
- CHECK(GetKind() == TokKind::kFusionKind);
- return fusion_kind_val_;
- }
int64 GetInt64Val() const {
CHECK(GetKind() == TokKind::kInt);
return int64_val_;
@@ -114,8 +106,6 @@ class HloLexer {
TokKind current_kind_;
string str_val_;
Shape shape_val_;
- HloOpcode opcode_val_;
- HloInstruction::FusionKind fusion_kind_val_;
int64 int64_val_;
double decimal_val_;
};
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index 47979ec6f3..ddc1e69951 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -104,6 +105,7 @@ class HloParser {
kPaddingConfig,
kMetadata,
kFusionKind,
+ kDistribution,
};
struct AttrConfig {
@@ -174,6 +176,7 @@ class HloParser {
bool ParseShape(Shape* result);
bool ParseOpcode(HloOpcode* result);
bool ParseFusionKind(HloInstruction::FusionKind* result);
+ bool ParseRandomDistribution(RandomDistribution* result);
bool ParseInt64(int64* result);
bool ParseDouble(double* result);
bool ParseBool(bool* result);
@@ -816,10 +819,36 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
shape, operands[0], config ? *config : ""));
break;
}
+ case HloOpcode::kRng: {
+ optional<RandomDistribution> distribution;
+ attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution,
+ &distribution};
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateRng(shape, *distribution, operands));
+ break;
+ }
+ case HloOpcode::kReducePrecision: {
+ optional<int64> exponent_bits;
+ optional<int64> mantissa_bits;
+ attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64,
+ &exponent_bits};
+ attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64,
+ &mantissa_bits};
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction =
+ builder->AddInstruction(HloInstruction::CreateReducePrecision(
+ shape, operands[0], static_cast<int>(*exponent_bits),
+ static_cast<int>(*mantissa_bits)));
+ break;
+ }
case HloOpcode::kConditional:
case HloOpcode::kCustomCall:
- case HloOpcode::kReducePrecision:
- case HloOpcode::kRng:
case HloOpcode::kTrace:
return TokenError(StrCat("parsing not yet implemented for op: ",
HloOpcodeString(opcode)));
@@ -1548,6 +1577,15 @@ bool HloParser::ParseAttributeHelper(
static_cast<optional<OpMetadata>*>(attr_out_ptr)->emplace(result);
return true;
}
+ case AttrTy::kDistribution: {
+ RandomDistribution result;
+ if (!ParseRandomDistribution(&result)) {
+ return false;
+ }
+ static_cast<optional<RandomDistribution>*>(attr_out_ptr)
+ ->emplace(result);
+ return true;
+ }
}
}();
if (!success) {
@@ -2024,20 +2062,51 @@ bool HloParser::ParseMetadata(OpMetadata* metadata) {
bool HloParser::ParseOpcode(HloOpcode* result) {
VLOG(1) << "ParseOpcode";
- if (lexer_.GetKind() != TokKind::kOpcode) {
+ if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects opcode");
}
- *result = lexer_.GetOpcodeVal();
+ string val = lexer_.GetStrVal();
+ auto status_or_result = StringToHloOpcode(val);
+ if (!status_or_result.ok()) {
+ return TokenError(
+ Printf("expects opcode but sees: %s, error: %s", val.c_str(),
+ status_or_result.status().error_message().c_str()));
+ }
+ *result = status_or_result.ValueOrDie();
lexer_.Lex();
return true;
}
bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) {
VLOG(1) << "ParseFusionKind";
- if (lexer_.GetKind() != TokKind::kFusionKind) {
+ if (lexer_.GetKind() != TokKind::kIdent) {
return TokenError("expects fusion kind");
}
- *result = lexer_.GetFusionKindVal();
+ string val = lexer_.GetStrVal();
+ auto status_or_result = StringToFusionKind(val);
+ if (!status_or_result.ok()) {
+ return TokenError(
+ Printf("expects fusion kind but sees: %s, error: %s", val.c_str(),
+ status_or_result.status().error_message().c_str()));
+ }
+ *result = status_or_result.ValueOrDie();
+ lexer_.Lex();
+ return true;
+}
+
+bool HloParser::ParseRandomDistribution(RandomDistribution* result) {
+ VLOG(1) << "ParseRandomDistribution";
+ if (lexer_.GetKind() != TokKind::kIdent) {
+ return TokenError("expects random distribution");
+ }
+ string val = lexer_.GetStrVal();
+ auto status_or_result = StringToRandomDistribution(val);
+ if (!status_or_result.ok()) {
+ return TokenError(
+ Printf("expects random distribution but sees: %s, error: %s",
+ val.c_str(), status_or_result.status().error_message().c_str()));
+ }
+ *result = status_or_result.ValueOrDie();
lexer_.Lex();
return true;
}
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index 90cdb87a1e..69d48d65bc 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -655,6 +655,31 @@ ENTRY %InfeedToOutfeed () -> (u32[3], pred[]) {
}
)"
+},
+// Rng
+{
+"Rng",
+R"(HloModule rng_module:
+
+ENTRY %Rng () -> f32[8] {
+ %constant = f32[] constant(0)
+ %constant.1 = f32[] constant(1)
+ ROOT %rng = f32[8]{0} rng(f32[] %constant, f32[] %constant.1), distribution=rng_uniform
+}
+
+)"
+},
+// Reduce precision
+{
+"ReducePrevison",
+R"(HloModule reduce_precision:
+
+ENTRY %ReducePrecision () -> f32[1] {
+ %constant = f32[1]{0} constant({3.14159})
+ ROOT %reduce-precision = f32[1]{0} reduce-precision(f32[1]{0} %constant), exponent_bits=8, mantissa_bits=10
+}
+
+)"
}
});
// clang-format on
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h
index 07e48804d0..7928bee5c2 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_token.h
+++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h
@@ -18,6 +18,9 @@ limitations under the License.
#include <string>
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/types.h"
+
namespace xla {
namespace tools {
@@ -60,10 +63,9 @@ enum class TokKind {
kDimLabels, // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,}
kDxD, // [0-9]+(x[0-9]+)+
kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)*
+ kIdent, // other identifiers
kString, // "abcd\"\n"
kShape, // f32[2,3]{1,0}
- kOpcode, // add
- kFusionKind, // kLoop, kOutput, ...
kInt, // 42
kDecimal, // 4.2
};