fix shape training
This commit is contained in:
@@ -190,7 +190,7 @@ if __name__ == "__main__":
|
||||
precision=amp_type,
|
||||
callbacks=callbacks,
|
||||
accelerator="gpu",
|
||||
devices=training_cfg.num_gpus,
|
||||
devices=args.num_gpus,
|
||||
num_nodes=training_cfg.num_nodes,
|
||||
strategy=ddp_strategy,
|
||||
gradient_clip_val=training_cfg.get('gradient_clip_val'),
|
||||
|
||||
Reference in New Issue
Block a user