Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Make saliency interpreter GPU compatible #5656

Merged
merged 4 commits into from
Jun 20, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _aggregate_token_embeddings(
embeddings_list: List[torch.Tensor], token_offsets: List[torch.Tensor]
) -> List[numpy.ndarray]:
if len(token_offsets) == 0:
return [embeddings.numpy() for embeddings in embeddings_list]
return [embeddings.detach().cpu().numpy() for embeddings in embeddings_list]
aggregated_embeddings = []
# NOTE: This is assuming that embeddings and offsets come in the same order, which may not
# be true. But, the intersection of using multiple TextFields with mismatched indexers is
Expand All @@ -60,5 +60,5 @@ def _aggregate_token_embeddings(

# All the places where the span length is zero, write in zeros.
embeddings[(span_embeddings_len == 0).expand(embeddings.shape)] = 0
aggregated_embeddings.append(embeddings.numpy())
aggregated_embeddings.append(embeddings.detach().cpu().numpy())
return aggregated_embeddings