fix the example of pytorch_lightning_mnist.py
- remove unused arg parameters
- fix model test issue on GPU

Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com>
chongxiaoc committed Oct 27, 2021
1 parent a776186 commit 4c69ce8
1 changed file: examples/pytorch/pytorch_lightning_mnist.py (3 additions, 14 deletions)
```diff
@@ -24,22 +24,10 @@
                     help='input batch size for testing (default: 1000)')
 parser.add_argument('--epochs', type=int, default=10, metavar='N',
                     help='number of epochs to train (default: 10)')
-parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
-                    help='learning rate (default: 0.01)')
-parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
-                    help='SGD momentum (default: 0.5)')
 parser.add_argument('--no-cuda', action='store_true', default=False,
                     help='disables CUDA training')
-parser.add_argument('--seed', type=int, default=42, metavar='S',
-                    help='random seed (default: 42)')
-parser.add_argument('--log-interval', type=int, default=10, metavar='N',
-                    help='how many batches to wait before logging training status')
-parser.add_argument('--fp16-allreduce', action='store_true', default=False,
-                    help='use fp16 compression during allreduce')
-parser.add_argument('--use-adasum', action='store_true', default=False,
-                    help='use adasum algorithm to do reduction')
-parser.add_argument('--gradient-predivide-factor', type=float, default=1.0,
-                    help='apply gradient predivide factor in optimizer (default: 1.0)')
 parser.add_argument('--data-dir',
                     help='location of the training dataset in the local filesystem (will be downloaded if needed)')
```

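The `--no-cuda` flag survives the cleanup because `args.cuda` is consulted further down in the diff. The line that derives `args.cuda` sits outside this hunk; below is a minimal sketch of the usual pattern, assuming the example follows the standard Horovod/PyTorch convention:

```python
import argparse

import torch

parser = argparse.ArgumentParser(description='PyTorch Lightning MNIST Example')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
args = parser.parse_args()

# Assumed derivation (not part of this diff): honor --no-cuda even when a GPU
# is present, and fall back to CPU when CUDA is unavailable.
args.cuda = not args.no_cuda and torch.cuda.is_available()
```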
```diff
@@ -205,7 +193,7 @@ def on_train_end(self, trainer, model):
     callbacks = [MyDummyCallback(), ModelCheckpoint(dirpath=ckpt_path)]

     trainer = Trainer(accelerator='horovod',
-                      gpus=(1 if torch.cuda.is_available() else 0),
+                      gpus=(1 if args.cuda else 0),
                       callbacks=callbacks,
                       max_epochs=epochs,
                       limit_train_batches=train_percent,
@@ -214,6 +202,7 @@ def on_train_end(self, trainer, model):
                       num_sanity_val_steps=0)

     trainer.fit(model)

+    if args.cuda:
+        model = model.cuda()
     test()
```
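The second fix targets the `test()` call: after `trainer.fit()` the model is not guaranteed to still sit on the GPU, so a hand-rolled evaluation loop that feeds it CUDA tensors can fail with a device mismatch, and moving the model back with `model.cuda()` restores agreement. The example's `test()` is not shown in this diff; the sketch below is a hypothetical stand-in (the names `evaluate` and `test_loader` are illustrative) for the kind of loop the fix supports:

```python
import torch


def evaluate(model, test_loader, use_cuda):
    """Hypothetical stand-in for the example's test() routine."""
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            if use_cuda:
                # Inputs go to the GPU, so the model must already live there,
                # which is what `model = model.cuda()` before test() ensures.
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    print(f'Test accuracy: {correct / len(test_loader.dataset):.4f}')
```

Under these assumptions it would be called as `evaluate(model, test_loader, args.cuda)` right after the conditional `model.cuda()` move.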
