Automatic Mixed Precision
Automatic Mixed Precision#
Automatic mixed precision (AMP) lets you train your models faster by using a lower precision datatype for operations like linear layers and convolutions.
You can train your Torch model with AMP by:
Adding
ray.train.torch.accelerate()withamp=Trueto the top of your training function.Wrapping your optimizer with
ray.train.torch.prepare_optimizer().Replacing your backward call with
ray.train.torch.backward().
def train_func():
+ train.torch.accelerate(amp=True)
model = NeuralNetwork()
model = train.torch.prepare_model(model)
data_loader = DataLoader(my_dataset, batch_size=worker_batch_size)
data_loader = train.torch.prepare_data_loader(data_loader)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+ optimizer = train.torch.prepare_optimizer(optimizer)
model.train()
for epoch in range(90):
for images, targets in dataloader:
optimizer.zero_grad()
outputs = model(images)
loss = torch.nn.functional.cross_entropy(outputs, targets)
- loss.backward()
+ train.torch.backward(loss)
optimizer.step()
...
Note
The performance of AMP varies based on GPU architecture, model type, and data shape. For certain workflows, AMP may perform worse than full-precision training.