Skip to content

Commit

Permalink
fixed rounding of values in cli output
Browse files Browse the repository at this point in the history
  • Loading branch information
lsickert committed Nov 22, 2022
1 parent a2a2021 commit 7b3c4fd
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions inseq/data/viz.py
Expand Up @@ -236,7 +236,7 @@ def get_saliency_heatmap_rich(
color = Color.from_rgb(*input_colors[row_index][col_index])
score = ""
if not np.isnan(scores[row_index][col_index]):
score = round(scores[row_index][col_index], 2)
score = round(float(scores[row_index][col_index]), 2)
row.append(Text(f"{score}", justify="center", style=Style(color=color)))
table.add_row(*row, end_section=row_index == scores.shape[0] - 1)
if step_scores is not None:
Expand All @@ -248,7 +248,7 @@ def get_saliency_heatmap_rich(
style = lambda val: "bold" if abs(val) >= threshold else ""
score_row = [Text(step_score_name, style="bold")]
for score in step_score_values:
score_row.append(Text(f"{score:.2f}", justify="center", style=style(score)))
score_row.append(Text(f"{score:.2f}", justify="center", style=style(round(float(score), 2))))
table.add_row(*score_row, end_section=True)
return table

Expand Down

0 comments on commit 7b3c4fd

Please sign in to comment.