This repository has been archived by the owner on May 14, 2024. It is now read-only.
Got an error when using lazy. #33
Oh, the code where I used lazy is like this:

```python
# assumes: import numpy as np; import torch; from torch.autograd import Variable
# datatmp, Batch, and lazy come from the surrounding script
def get_batches(sz, pad=0):
    for i in range(0, len(datatmp), sz):
        srcdata = []
        trgdata = []
        for j in range(sz):
            srcdata.append(datatmp[i + j][0])  # each appended item is a list
            trgdata.append(datatmp[i + j][1])  # same for trg
        src_max_seq_length = max(len(s) for s in srcdata)
        trg_max_seq_length = max(len(t) for t in trgdata)
        # pad src to src_max_seq_length
        for k in range(len(srcdata)):
            srcdata[k] = srcdata[k] + [pad] * (src_max_seq_length - len(srcdata[k]))
        # pad trg to trg_max_seq_length
        for k in range(len(trgdata)):
            trgdata[k] = trgdata[k] + [pad] * (trg_max_seq_length - len(trgdata[k]))
        sr = np.array(srcdata)
        tg = np.array(trgdata)
        src = Variable(torch.from_numpy(sr), requires_grad=False).long()
        trg = Variable(torch.from_numpy(tg), requires_grad=False).long()
        (src, trg) = lazy(src, trg, batch=0)  # <- the error happens here
        yield Batch(src, trg, pad)
```
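For reference, the two padding loops in the snippet boil down to this minimal, framework-free sketch (the helper name `pad_batch` is just for illustration):

```python
def pad_batch(seqs, pad=0):
    """Pad each list in seqs with `pad` up to the length of the longest one."""
    max_len = max(len(s) for s in seqs)
    return [s + [pad] * (max_len - len(s)) for s in seqs]

print(pad_batch([[1, 2, 3], [4, 5]]))  # [[1, 2, 3], [4, 5, 0]]
```

In PyTorch itself the same thing is usually done with `torch.nn.utils.rnn.pad_sequence(..., batch_first=True, padding_value=pad)`, which returns the padded tensor directly.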
Hmm, sorry for the late response. It seems to me that you're using PyTorch 0.4.*, right? I didn't test versions < 1, so I'm not sure where the issue comes from. If I had to guess, it's perhaps because of a mismatch in the API between versions.
I'm sorry, but I get the same error with torch 2.0.0 and with plain torch.from_numpy (not using Variable).
I see. Seems to be an oversight on my part: I didn't handle the broadcasting mechanism for primitives. Thanks for the feedback!
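For anyone hitting the same trace: the failure mode can be reproduced without the library at all (a standalone sketch using a hypothetical `FakeTensor` stand-in, not the actual lazy internals):

```python
class FakeTensor:
    """Stand-in for torch.Tensor, just enough to show the failure mode."""
    def __init__(self, *shape):
        self.shape = shape

    def size(self):
        return self.shape

def shape_of(x):
    return x.size()            # fine for tensors...

def shape_of_safe(x):
    if isinstance(x, (int, float)):
        return ()              # ...but plain Python numbers need a guard:
    return x.size()            # treat them as 0-dim, broadcastable to any shape

assert shape_of(FakeTensor(2, 3)) == (2, 3)
try:
    shape_of(5)                # AttributeError: 'int' object has no attribute 'size'
except AttributeError:
    pass
assert shape_of_safe(5) == ()
```

This mirrors the reported "'int' object has no attribute 'size'": code that assumes every operand is a tensor breaks the moment a Python scalar is broadcast in.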
I'm doing an NMT task. I use my own data-loading function rather than a torch Dataset, and I got an "'int' object has no attribute 'size'" error.
Here's my data-loading code:
PS: the code is adapted from 'The Annotated Transformer'.