Skip to content

Commit

Permalink
MLIR emitter: Disable TanH and Clamp/min/max.
Browse files Browse the repository at this point in the history
TanH: The default lowering is too inaccurate.
The others don't handle NaNs correctly.
PiperOrigin-RevId: 609328010
  • Loading branch information
jreiffers authored and tensorflower-gardener committed Feb 22, 2024
1 parent 7756873 commit a57309f
Showing 1 changed file with 5 additions and 16 deletions.
Expand Up @@ -142,16 +142,11 @@ static auto& kUnsupportedOps =
HloOpcode::kCall};

static auto& kUnimplementedOps = *new absl::flat_hash_set<HloOpcode>{
HloOpcode::kConvolution,
HloOpcode::kDot,
HloOpcode::kDynamicUpdateSlice,
HloOpcode::kMap,
HloOpcode::kReduceWindow,
// Has a custom approximation in XLA:
HloOpcode::kErf,
};

static auto& kF32SupportedOps = *new absl::flat_hash_set<HloOpcode>{
HloOpcode::kConvolution, HloOpcode::kDot, HloOpcode::kDynamicUpdateSlice,
HloOpcode::kMap, HloOpcode::kReduceWindow,
// Custom approximations in XLA:
HloOpcode::kErf, HloOpcode::kTanh,
// Incorrect NaN handling:
HloOpcode::kMaximum, HloOpcode::kMinimum, HloOpcode::kClamp};

bool IsUnsupportedConstant(const HloInstruction* instr) {
Expand Down Expand Up @@ -776,12 +771,6 @@ bool IsHloOpSupported(const HloInstruction* instr,
return false;
}

// TODO(jreiffers): Fix the F64 lowering for these ops.
if (kF32SupportedOps.contains(instr->opcode()) &&
instr->shape().element_type() == F64) {
return false;
}

return !(kUnsupportedOps.contains(instr->opcode()) ||
kUnimplementedOps.contains(instr->opcode()) ||
IsUnsupportedConstant(instr) || IsUnsupportedTuple(instr) ||
Expand Down

0 comments on commit a57309f

Please sign in to comment.