Distributed Training Scenario

This tutorial demonstrates how to record and display the information, metrics, and files produced by training in a PyTorchTrainingJob distributed training scenario (using data-parallel training of a PyTorch model as an example).

Running the Example

Please prepare the environment by following the usage instructions, then go to the example that accompanies this tutorial and run it according to its README.

The following sections describe the modifications made to the training script in order to track the training. The method is generic and can be extended to any type of TrainingJob, including DeepSpeedJob, ColossalAIJob, MPIJob, and GenericJob.
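
The only PyTorch-specific part of the pattern is how the designated worker is identified. Below is a minimal, launcher-agnostic sketch; the environment variable names are assumptions that depend on your launcher (torchrun sets RANK, while Open MPI sets OMPI_COMM_WORLD_RANK):

import os

def is_chief():
    # Assumption: the launcher exposes the process rank through an
    # environment variable (torchrun: RANK; Open MPI: OMPI_COMM_WORLD_RANK).
    for var in ('RANK', 'OMPI_COMM_WORLD_RANK'):
        if var in os.environ:
            return int(os.environ[var]) == 0
    return True  # single-process fallback

if is_chief():
    pass  # only this worker creates the Run, logs metrics, and uploads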

Preparing the Training Script

Prepare a PyTorch training script whose model classifies images from the MNIST dataset, using the DistributedDataParallel (DDP) module for distributed training, as shown below (this is the same script as in Data-Parallel Training with PyTorchTrainingJob). We will then make a few small modifications on top of this script to enable tracking.

torch_mnist_trainingjob_em.py
import argparse
import logging
import os
import shutil

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

parser = argparse.ArgumentParser(
    description='Distributed training of PyTorch model for MNIST with DDP.')
parser.add_argument(
    '--backend',
    type=str,
    help='Distributed backend',
    choices=[dist.Backend.GLOO, dist.Backend.NCCL, dist.Backend.MPI],
    default=dist.Backend.GLOO)
parser.add_argument('--log_dir',
                    type=str,
                    default='/mnt/log',
                    help='Path of the TensorBoard log directory.')
parser.add_argument('--no_cuda',
                    action='store_true',
                    default=False,
                    help='Disable CUDA training.')
logging.basicConfig(format='%(message)s', level=logging.INFO)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.conv3 = nn.Conv2d(64, 64, 3, 1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dense1 = nn.Linear(576, 64)
        self.dense2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = torch.flatten(x, 1)
        x = F.relu(self.dense1(x))
        output = self.dense2(x)  # raw logits: nn.CrossEntropyLoss applies softmax itself
        return output


def train(scheduler):
    global global_step
    for epoch in range(1, epochs + 1):
        model.train()
        for step, (data, target) in enumerate(train_loader, 1):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            if step % (500 // world_size) == 0:
                train_loss = loss.item()
                logging.info(
                    'epoch {:d}/{:d}, batch {:5d}/{:d} with loss: {:.4f}'.
                    format(epoch, epochs, step, steps_per_epoch, train_loss))
                global_step = (epoch - 1) * steps_per_epoch + step

                if args.log_dir and rank == 0:
                    writer.add_scalar('train/loss', train_loss, global_step)

        scheduler.step()
        global_step = epoch * steps_per_epoch
        test(val=True, epoch=epoch)


def test(val=False, epoch=None):
    label = 'val' if val else 'test'
    model.eval()
    running_loss = 0.0
    correct = 0

    with torch.no_grad():
        loader = val_loader if val else test_loader
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            running_loss += loss.item()
            prediction = output.max(1)[1]
            correct += (prediction == target).sum().item()

    test_loss = running_loss / len(loader)
    test_accuracy = correct / len(loader.dataset)
    msg = '{:s} loss: {:.4f}, {:s} accuracy: {:.4f}'.format(
        label, test_loss, label, test_accuracy)
    if val:
        msg = 'epoch {:d}/{:d} with '.format(epoch, epochs) + msg
    logging.info(msg)

    if args.log_dir and rank == 0:
        writer.add_scalar('{:s}/loss'.format(label), test_loss, global_step)
        writer.add_scalar('{:s}/accuracy'.format(label), test_accuracy,
                          global_step)


if __name__ == '__main__':
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    if use_cuda:
        logging.info('Using CUDA')
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    logging.info('Using distributed PyTorch with {} backend'.format(
        args.backend))
    dist.init_process_group(backend=args.backend)
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    torch.manual_seed(1)

    model = Net().to(device)
    model = DDP(model)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001 * world_size)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5,), (0.5,))])
    train_dataset = datasets.MNIST(root='./data',
                                   train=True,
                                   download=False,
                                   transform=transform)
    train_dataset, val_dataset = torch.utils.data.random_split(
        train_dataset, [48000, 12000])
    test_dataset = datasets.MNIST(root='./data',
                                  train=False,
                                  download=False,
                                  transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=32 * world_size,
                                               shuffle=True,
                                               **kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=400,
                                             shuffle=False,
                                             **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=1000,
                                              shuffle=False,
                                              **kwargs)

    if args.log_dir and rank == 0:
        log_dir = args.log_dir
        if os.path.exists(log_dir):
            shutil.rmtree(log_dir, ignore_errors=True)
        writer = SummaryWriter(log_dir)

    global_step = 0
    epochs = 10
    steps_per_epoch = len(train_loader)
    train(scheduler)
    test()
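
Before adding tracking, you can sanity-check the script outside the cluster. Below is a sketch of a local launch, assuming four CPU processes and that the MNIST data is already present under ./data (the script passes download=False); torchrun sets the RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT variables that dist.init_process_group() reads by default:

torchrun --nproc_per_node 4 torch_mnist_trainingjob_em.py --backend gloo --no_cuda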

Creating a Run

Create a Python dict containing all of the hyperparameters (that you want to record).

from t9k import em

if __name__ == '__main__':
    ...
    hparams = {
        'batch_size': 32 * world_size,
        'epochs': 10,
        'learning_rate': 0.001 * world_size,
        'learning_rate_decay_period': 1,
        'learning_rate_decay_factor': 0.7,
        'conv_channels1': 32,
        'conv_channels2': 64,
        'conv_channels3': 64,
        'conv_kernel_size': 3,
        'maxpool_size': 2,
        'linear_features1': 64,
        'seed': 1,
    }
    ...

Designate one worker to create and initialize a Run instance before the model is built, passing in a name and the hyperparameter dict above.

if __name__ == '__main__':
    ...
    if rank == 0:
        run = em.create_run(name='mnist_torch_distributed', hparams=hparams)
    ...
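
Note that run exists only on rank 0, so every later use must stay behind the same rank check. A slightly more defensive variant (a sketch, not part of the original example) initializes run to None on every worker, so that a forgotten guard fails with an obvious AttributeError rather than a NameError:

if __name__ == '__main__':
    ...
    run = None
    if rank == 0:
        run = em.create_run(name='mnist_torch_distributed', hparams=hparams)
    ...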

Configuring the Model with the Hyperparameters

Replace the directly supplied hyperparameter values with values taken from the dict above when configuring the model layers, the dataset, the optimizer, the training procedure, and so on.

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, hparams['conv_channels1'],
                               hparams['conv_kernel_size'], 1)
        self.conv2 = nn.Conv2d(hparams['conv_channels1'],
                               hparams['conv_channels2'],
                               hparams['conv_kernel_size'], 1)
        self.conv3 = nn.Conv2d(hparams['conv_channels2'],
                               hparams['conv_channels3'],
                               hparams['conv_kernel_size'], 1)
        self.pool = nn.MaxPool2d(hparams['maxpool_size'],
                                 hparams['maxpool_size'])
        self.dense1 = nn.Linear(576, hparams['linear_features1'])
        self.dense2 = nn.Linear(hparams['linear_features1'], 10)
    ...

if __name__ == '__main__':
    ...
    torch.manual_seed(hparams['seed'])
    ...
    optimizer = optim.Adam(model.parameters(), lr=hparams['learning_rate'])
    scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=hparams['learning_rate_decay_period'],
        gamma=hparams['learning_rate_decay_factor'])
    ...
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=hparams['batch_size'],
        shuffle=True,
        **kwargs)
    ...
    epochs = hparams['epochs']
    ...

Logging Metrics

During model training and testing, the designated worker calls the Run instance's log() method to record the metrics produced along the way.

def train(scheduler):
    ...
            if step % (500 // world_size) == 0:
                train_loss = loss.item()
                logging.info(
                    'epoch {:d}/{:d}, batch {:5d}/{:d} with loss: {:.4f}'.
                    format(epoch, epochs, step, steps_per_epoch, train_loss))
                global_step = (epoch - 1) * steps_per_epoch + step

                if rank == 0:
                    run.log(type='train',                  # log training metrics
                            metrics={'loss': train_loss},  # metric names and values
                            step=global_step,              # current global step
                            epoch=epoch)                   # current epoch
    ...

def test(val=False, epoch=None):
    ...
    test_loss = running_loss / len(loader)
    test_accuracy = correct / len(loader.dataset)
    msg = '{:s} loss: {:.4f}, {:s} accuracy: {:.4f}'.format(
        label, test_loss, label, test_accuracy)
    if val:
        msg = 'epoch {:d}/{:d} with '.format(epoch, epochs) + msg
    logging.info(msg)

    if rank == 0:
        run.log(type=label,    # log validation/test metrics
                metrics={
                    'loss': test_loss,
                    'accuracy': test_accuracy,
                },
                step=global_step,
                epoch=epoch)
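
Because each worker draws its own shuffled batches in train(), the train loss logged above reflects only rank 0's data. If you would rather log a loss averaged over all workers, you can all-reduce it before logging. A minimal sketch (not part of the original example), using the torch.distributed module already imported as dist:

def train(scheduler):
    ...
            if step % (500 // world_size) == 0:
                # Sum the per-worker losses across all processes, then average.
                # Every worker must reach this collective call, so it stays
                # outside the rank check.
                loss_tensor = loss.detach().clone()
                dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
                train_loss = loss_tensor.item() / world_size
                global_step = (epoch - 1) * steps_per_epoch + step

                if rank == 0:
                    run.log(type='train',
                            metrics={'loss': train_loss},
                            step=global_step,
                            epoch=epoch)
    ...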

Creating an Artifact and Adding the Model Checkpoint File

After the model checkpoint file has been saved, the designated worker creates and initializes a new Artifact instance, adds the checkpoint file to it, and marks it as an output of the Run.

if __name__ == '__main__':
    ...
    if rank == 0:
        torch.save(model.state_dict(), 'model_state_dict.pt')
        model_artifact = em.create_artifact(name='mnist_torch_saved_model')
        model_artifact.add_file('model_state_dict.pt')
        run.mark_output(model_artifact)
    ...
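
One caveat in the snippet above: model is wrapped in DDP, so model.state_dict() stores every parameter key with a module. prefix. If you plan to load the checkpoint back into a plain Net, a common variant (a sketch, not what the original example does) saves the state dict of the unwrapped module instead:

    if rank == 0:
        # model.module is the underlying Net; its state-dict keys carry no
        # 'module.' prefix, so the file loads directly into a plain Net().
        torch.save(model.module.state_dict(), 'model_state_dict.pt')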

Finishing and Uploading the Run

After model training and testing are finished, the designated worker calls the Run instance's finish() and upload() methods to finish and upload the Run (the Artifact is uploaded along with it). Before uploading, call the em.login() function to log in to the server.

if __name__ == '__main__':
    ...
    if rank == 0:
        run.finish()
        em.login()
        run.upload(folder='em-examples', make_folder=True)

Checking the Run and Artifact

After training finishes, open the experiment management console, where you can see that the Run named mnist_torch_distributed and the Artifact it output have been uploaded:

[Screenshot: the uploaded Run and Artifact]

Click the name of the Run or the Artifact to open its details page, where you can view the Run's platform information, metrics, hyperparameters, and dataflow, as well as the Artifact's files and dataflow. Their dataflows are connected to each other.

[Screenshots: Run details, Artifact details, and dataflow]