
Cannot use `decollate_batch` on torch tensors, and stacking MetaTensors cannot account for differing metadata across samples #1620

Open
0tist opened this issue Jan 19, 2024 · 1 comment

Comments


0tist commented Jan 19, 2024

Bug Description
I'm using a custom collate function for a DataLoader. The scans are fetched at random from a collection gathered from different machines (e.g. MRI and CT scanners). I want to preserve the metadata for all of these tensors, but I can't do that with MONAI's MetaTensor, so I store the metadata in a separate list keyed by index and pass the collated torch tensors on.

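Code for the custom collate function:

```python
import torch


def val_datalist_collate(self, batch):
    # Method of a class; `self.enable_binary_class` is set elsewhere in that class.
    imgs = []
    labels = []
    batch_output = {}
    for sample in batch:
        imgs.append(sample['image'])
        lbl = sample['label']
        if self.enable_binary_class:
            # Collapse every foreground class into 1 (modifies lbl in place).
            lbl[lbl > 0] = 1
        labels.append(lbl)

    # Stack per-sample tensors into (B, C, ...) batches. If the samples are
    # MONAI MetaTensors, torch.stack returns a MetaTensor as well.
    batch_output['image'] = torch.stack(imgs)
    batch_output['label'] = torch.stack(labels)

    return batch_output
```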
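The snippet above stacks whatever tensors the dataset returns; it does not itself keep the metadata in a separate list as described in the bug description. A minimal sketch of that described approach follows. The function name `collate_with_separate_meta`, the `image_meta`/`label_meta` keys and the `enable_binary_class` argument are hypothetical, and it assumes each sample is a dict whose values are plain tensors or MONAI `MetaTensor`s (which expose `.meta` and `.as_tensor()`).

```python
import torch


def collate_with_separate_meta(batch, enable_binary_class=False):
    """Hypothetical collate_fn: stack plain tensors, keep metadata in lists.

    Assumes each sample is a dict with "image" and "label" tensors, which may
    be MONAI MetaTensors (objects with .meta and .as_tensor()).
    """
    imgs, labels, img_metas, lbl_metas = [], [], [], []
    for sample in batch:
        img, lbl = sample["image"], sample["label"]
        # Record metadata in sample order (empty dict for plain tensors).
        img_metas.append(dict(getattr(img, "meta", {})))
        lbl_metas.append(dict(getattr(lbl, "meta", {})))
        # Drop the MetaTensor wrapper so torch.stack returns a plain tensor.
        if hasattr(img, "as_tensor"):
            img = img.as_tensor()
        if hasattr(lbl, "as_tensor"):
            lbl = lbl.as_tensor()
        if enable_binary_class:
            # Merge all foreground classes into 1 without mutating the input.
            lbl = (lbl > 0).to(lbl.dtype)
        imgs.append(img)
        labels.append(lbl)
    return {
        "image": torch.stack(imgs),
        "label": torch.stack(labels),
        "image_meta": img_metas,  # entry i describes batch["image"][i]
        "label_meta": lbl_metas,
    }
```

It could be passed to a `torch.utils.data.DataLoader` as `collate_fn=functools.partial(collate_with_separate_meta, enable_binary_class=True)`. Unlike the original snippet, it binarizes labels without modifying the dataset's tensors in place.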

To Reproduce
Create a DataLoader with a custom collate function that returns stacked torch tensors, then call `decollate_batch` on a batch loaded from that DataLoader. `decollate_batch` was imported from `monai.data`:
`from monai.data import decollate_batch`
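A possible workaround, sketched under assumptions rather than taken from the thread: for a plain `torch.Tensor`, decollating a batch amounts to splitting it along dimension 0, which `torch.unbind` does directly. The first `decollate_batch` frame in the error log stops inside its `if isinstance(batch, MetaObj):` branch, which suggests `Y` was a MONAI `MetaTensor` rather than a plain tensor, so the hypothetical helper `split_batch` below strips that wrapper (via `MetaTensor.as_tensor()`) before splitting. Per-sample metadata is discarded, so this only fits when the metadata is tracked separately, as described above.

```python
import torch


def split_batch(batch: torch.Tensor) -> list:
    """Hypothetical fallback: split a batched tensor into per-sample tensors.

    If the input is a MONAI MetaTensor, its wrapper (and metadata) is dropped
    first via .as_tensor(), so only plain tensors are returned.
    """
    if hasattr(batch, "as_tensor"):
        batch = batch.as_tensor()
    # Detach like decollate_batch's default (detach=True), then remove dim 0.
    return list(torch.unbind(batch.detach(), dim=0))
```

In `val_one_iter`, `val_labels_list = split_batch(Y)` would stand in for `decollate_batch(Y)`.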

Error log

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[7], line 1
----> 1 mm.train()

File ~/opet/opet/src/main.py:210, in ModelMaker.train(self)
    207             mlflow.log_param(f'{k}/{k_}', v_)
    209 ##### RUN TRAINING ######
--> 210 self.trainer.train(self.exp_id, run_id)

File ~/opet/opet/src/segmentation_3D/core/train.py:70, in trainer.train(self, exp_id, run_id)
     68 patience -= 1
     69 if not(patience):
---> 70     mean_dice_val, val_loss = self.validate(val_loader)
     71     state_dict = {'exp_id': exp_id,
     72                   'run_id': run_id,
     73                   'optimizer': self.optimizer.state_dict(),
     74                   'training_loss': epoch_tr_loss,
     75                   'val loss': val_loss,
     76                   'epoch': epoch+1}
     77     self.save_model(dice_val= mean_dice_val, **state_dict)

File ~/opet/opet/src/segmentation_3D/core/train.py:108, in trainer.validate(self, val_loader)
    106         x_val, y_val = val_sample['image'], val_sample['label']
    107         x_val, y_val = x_val.to(self.device), y_val.to(self.device)
--> 108         loss_val = self.val_one_iter(x_val, y_val)
    109         epoch_val_loss += loss_val
    111 dice_val = self.dice_metric.aggregate().item()

File ~/opet/opet/src/segmentation_3D/core/train.py:130, in trainer.val_one_iter(self, X, Y)
    127 post_pred = AsDiscrete(argmax=True, to_onehot=self.n_classes)
    129 val_outputs = sliding_window_inference(X, self.patch_size, self.num_samples, self.model)
--> 130 val_labels_list = decollate_batch(Y)
    131 val_labels_convert = [
    132     post_label(val_label_tensor) for val_label_tensor in val_labels_list
    133 ]
    134 val_outputs_list = decollate_batch(val_outputs)

File ~/miniconda3/envs/ailib/lib/python3.10/site-packages/monai/data/utils.py:619, in decollate_batch(batch, detach, pad, fill_value)
    617 # if of type MetaObj, decollate the metadata
    618 if isinstance(batch, MetaObj):
--> 619     for t, m in zip(out_list, decollate_batch(batch.meta)):
    620         if isinstance(t, MetaObj):
    621             t.meta = m

File ~/miniconda3/envs/ailib/lib/python3.10/site-packages/monai/data/utils.py:631, in decollate_batch(batch, detach, pad, fill_value)
    628         return [t.item() for t in out_list]
    629     return list(out_list)
--> 631 b, non_iterable, deco = _non_zipping_check(batch, detach, pad, fill_value)
    632 if b <= 0:  # all non-iterable, single item "batch"? {"image": 1, "label": 1}
    633     return deco

File ~/miniconda3/envs/ailib/lib/python3.10/site-packages/monai/data/utils.py:532, in _non_zipping_check(batch_data, detach, pad, fill_value)
    530 _deco: Mapping | Sequence
    531 if isinstance(batch_data, Mapping):
--> 532     _deco = {key: decollate_batch(batch_data[key], detach, pad=pad, fill_value=fill_value) for key in batch_data}
    533 elif isinstance(batch_data, Iterable):
    534     _deco = [decollate_batch(b, detach, pad=pad, fill_value=fill_value) for b in batch_data]

File ~/miniconda3/envs/ailib/lib/python3.10/site-packages/monai/data/utils.py:532, in <dictcomp>(.0)
    530 _deco: Mapping | Sequence
    531 if isinstance(batch_data, Mapping):
--> 532     _deco = {key: decollate_batch(batch_data[key], detach, pad=pad, fill_value=fill_value) for key in batch_data}
    533 elif isinstance(batch_data, Iterable):
    534     _deco = [decollate_batch(b, detach, pad=pad, fill_value=fill_value) for b in batch_data]

File ~/miniconda3/envs/ailib/lib/python3.10/site-packages/monai/data/utils.py:631, in decollate_batch(batch, detach, pad, fill_value)
    628         return [t.item() for t in out_list]
    629     return list(out_list)
--> 631 b, non_iterable, deco = _non_zipping_check(batch, detach, pad, fill_value)
    632 if b <= 0:  # all non-iterable, single item "batch"? {"image": 1, "label": 1}
    633     return deco

File ~/miniconda3/envs/ailib/lib/python3.10/site-packages/monai/data/utils.py:534, in _non_zipping_check(batch_data, detach, pad, fill_value)
    532     _deco = {key: decollate_batch(batch_data[key], detach, pad=pad, fill_value=fill_value) for key in batch_data}
    533 elif isinstance(batch_data, Iterable):
--> 534     _deco = [decollate_batch(b, detach, pad=pad, fill_value=fill_value) for b in batch_data]
    535 else:
    536     raise NotImplementedError(f"Unable to de-collate: {batch_data}, type: {type(batch_data)}.")

TypeError: iteration over a 0-d array
```
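The final `TypeError` is raised by NumPy: a 0-d array (a scalar wrapped in `np.ndarray`) passes an `isinstance(x, Iterable)` check because `ndarray` defines `__iter__`, yet iterating it fails. That matches the last frame, where `_non_zipping_check` takes its `Iterable` branch while decollating `batch.meta`. The log does not show which metadata key holds the 0-d array; scalar header fields read from the image files are a plausible candidate, but that is an assumption. The snippet below reproduces the error in isolation (`try_iterate` is a hypothetical helper):

```python
from collections.abc import Iterable

import numpy as np


def try_iterate(value):
    """Return None if `value` can be iterated, else the TypeError message."""
    try:
        list(value)  # same kind of loop as the list comprehension in _non_zipping_check
        return None
    except TypeError as err:
        return str(err)


scalar = np.array(3)  # 0-d array, e.g. a scalar value stored in a metadata dict
print(scalar.ndim)  # 0
print(isinstance(scalar, Iterable))  # True: ndarray defines __iter__
print(try_iterate(scalar))  # iteration over a 0-d array
```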
KumoLiu (Contributor) commented Jan 19, 2024

Hi @0tist, I think MONAI's `list_data_collate` can collate MetaTensors:
https://github.com/Project-MONAI/MONAI/blob/facf17693410d41170edd8e94364b4f341369aea/monai/data/utils.py#L505
