/
utils.py
103 lines (80 loc) · 3.63 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import math
from typing import List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def pad_sents_char(sents, char_pad_token, max_word_length=21):
    """ Pad list of sentences according to the longest sentence in the batch and max_word_length.
    @param sents (list[list[list[int]]]): list of sentences, result of `words2charindices()`
        from `vocab.py`; each sentence is a list of words, each word a list of character indices
    @param char_pad_token (int): index of the character-padding token
    @param max_word_length (int): number of characters every word is truncated or padded to
        (defaults to 21)
    @returns sents_padded (list[list[list[int]]]): list of sentences where sentences/words shorter
        than the max length sentence/word are padded out with the appropriate pad token, such that
        each sentence in the batch now has same number of words and each word has an equal
        number of characters. Words longer than `max_word_length` are truncated to exactly
        `max_word_length` characters. An empty batch yields an empty list.
        Output shape: (batch_size, max_sentence_length, max_word_length)
    """
    # Nothing to pad; also avoids calling max() on an empty sequence.
    if not sents:
        return []

    longest_sent = max(len(sent) for sent in sents)

    sents_padded = []
    for sent in sents:
        padded_sent = []
        for word in sent:
            # Keep at most `max_word_length` characters (not one fewer), then
            # right-pad short words up to that length. Slicing copies, so the
            # caller's word lists are never mutated.
            chars = list(word[:max_word_length])
            chars.extend([char_pad_token] * (max_word_length - len(chars)))
            padded_sent.append(chars)
        # Fill missing word slots with all-pad words; build a fresh list for
        # each slot so the padded words do not alias one another.
        padded_sent.extend(
            [char_pad_token] * max_word_length
            for _ in range(longest_sent - len(sent))
        )
        sents_padded.append(padded_sent)
    return sents_padded
def pad_sents(sents, pad_token):
    """ Pad list of sentences according to the longest sentence in the batch.
    @param sents (list[list[int]]): list of sentences, where each sentence
        is represented as a list of words
    @param pad_token (int): padding token
    @returns sents_padded (list[list[int]]): list of sentences where sentences shorter
        than the max length sentence are padded out with the pad_token, such that
        each sentence in the batch now has equal length. The input sentences are
        not modified. An empty batch yields an empty list.
        Output shape: (batch_size, max_sentence_length)
    """
    # Nothing to pad; also avoids calling max() on an empty sequence.
    if not sents:
        return []

    max_len = max(len(s) for s in sents)
    # Copy each sentence and right-pad it up to the longest length in the batch.
    return [list(s) + [pad_token] * (max_len - len(s)) for s in sents]
def read_corpus(file_path, source):
    r""" Read file, where each sentence is delineated by a `\n`.
    @param file_path (str): path to file containing corpus
    @param source (str): "tgt" or "src" indicating whether text
        is of the source language or target language
    @returns data (list[list[str]]): one token list per line of the file; each line is
        stripped and split on single spaces. Target sentences are wrapped in
        `<s>` ... `</s>` markers.
    """
    # only wrap the target sentence in <s> and </s>
    is_target = source == 'tgt'
    data = []
    # Use a context manager so the handle is closed deterministically, and pin
    # the encoding so the result does not depend on the platform default.
    with open(file_path, 'r', encoding='utf-8') as corpus:
        for line in corpus:
            sent = line.strip().split(' ')
            if is_target:
                sent = ['<s>'] + sent + ['</s>']
            data.append(sent)
    return data
def batch_iter(data, batch_size, shuffle=False):
    """ Yield batches of source and target sentences reverse sorted by length (largest to smallest).
    @param data (list of (src_sent, tgt_sent)): list of tuples containing source and target sentence
    @param batch_size (int): batch size; must be a positive integer
    @param shuffle (boolean): whether to randomly shuffle the dataset
    @yields (src_sents, tgt_sents): two parallel lists holding the source and target
        sentences of one batch, ordered by decreasing source-sentence length. The last
        batch may hold fewer than `batch_size` examples.
    @raises ValueError: if `batch_size` is smaller than 1 (raised when iteration starts,
        since this is a generator)
    """
    # A non-positive batch size would otherwise divide by zero or silently
    # produce no batches at all.
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer, got %r' % (batch_size,))

    index_array = list(range(len(data)))
    if shuffle:
        # Shuffle the indices, not `data`, so the caller's list is left intact.
        np.random.shuffle(index_array)

    for start in range(0, len(index_array), batch_size):
        indices = index_array[start:start + batch_size]
        # sorted() is stable: examples with equal source length keep their
        # relative (possibly shuffled) order.
        examples = sorted(
            (data[idx] for idx in indices),
            key=lambda e: len(e[0]),
            reverse=True,
        )
        src_sents = [e[0] for e in examples]
        tgt_sents = [e[1] for e in examples]
        yield src_sents, tgt_sents