Fine-tune dolly-v2-7b with Ray AIR LightningTrainer and FSDP#

In this example, we demonstrate how to use Ray AIR to fine-tune a dolly-v2-7b model. dolly-v2-7b is a 6.9 billion parameter causal language model created by Databricks, derived from EleutherAI’s Pythia-6.9b, and fine-tuned on a ~15K record instruction corpus.

We load the pre-trained model from the Hugging Face model hub into a LightningModule and launch an FSDP fine-tuning job across 16 T4 GPUs with the help of Ray LightningTrainer. Other large language models of similar size can be fine-tuned in the same way.

Before starting this example, we highly recommend reading Ray Train Key Concepts and Ray Data Key Concepts.

Set up Ray cluster#

In this example, we are using a Ray cluster with 16 g4dn.4xlarge instances. Each instance has one Tesla T4 GPU (16 GiB memory).

We define a runtime_env to install the necessary Python libraries on each node. You can skip this step if you have already installed all the required packages in your workers’ base image. We tested this example with pytorch_lightning==2.0.2 and transformers==4.29.2.

import ray

ray.init(
    runtime_env={
        "pip": [
            "datasets",
            "evaluate",
            "transformers>=4.26.0",
            "torch>=1.12.0",
            "pytorch_lightning>=2.0",
        ]
    }
)
MODEL_NAME = "databricks/dolly-v2-7b"
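
The runtime_env above specifies only lower bounds, so the versions that actually get installed may differ from the tested ones. The following is a minimal sketch for reporting what is installed in the current environment, using only the standard library. The helper name installed_versions is hypothetical and not part of Ray.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package name: installed version, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Packages listed in the runtime_env of this example.
report = installed_versions(
    ["pytorch_lightning", "transformers", "torch", "datasets", "evaluate"]
)
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```

Running this on a worker (for example inside a Ray task) rather than on the driver would show the environment that training actually uses.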

Prepare your data#

We use tiny_shakespeare for fine-tuning. It contains 40,000 lines from a variety of Shakespeare’s plays and was featured in Andrej Karpathy’s blog post ‘The Unreasonable Effectiveness of Recurrent Neural Networks’.

Dataset samples:

BAPTISTA:
I know him well: you are welcome for his sake.

GREMIO:
Saving your tale, Petruchio, I pray,
Let us, that are poor petitioners, speak too:
Baccare! you are marvellous forward.

PETRUCHIO:
O, pardon me, Signior Gremio; I would fain be doing.

Here, we have adopted similar pre-processing logic from another demo: GPT-J-6B Fine-Tuning with Ray AIR and DeepSpeed.

import ray
import pandas as pd
from datasets import load_dataset
from ray.data.preprocessors import BatchMapper, Chain
from transformers import AutoTokenizer, AutoModelForCausalLM

def split_text(batch: pd.DataFrame) -> pd.DataFrame:
    # Join all rows, split on newlines, and drop blank lines
    # and speaker labels (lines ending with ":").
    text = list(batch["text"])
    flat_text = "".join(text)
    lines = [
        x.strip()
        for x in flat_text.split("\n")
        if x.strip() and not x.strip().endswith(":")
    ]
    return pd.DataFrame(lines, columns=["text"])


def tokenize(batch: pd.DataFrame) -> dict:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token
    ret = tokenizer(
        list(batch["text"]),
        truncation=True,
        max_length=256,
        padding="max_length",
        return_tensors="np",
    )
    ret["labels"] = ret["input_ids"].copy()
    return dict(ret)

splitter = BatchMapper(split_text, batch_format="pandas")
tokenizer = BatchMapper(tokenize, batch_format="pandas")
preprocessor = Chain(splitter, tokenizer)

hf_dataset = load_dataset("tiny_shakespeare")
ray_datasets = {
    "train": ray.data.from_huggingface(hf_dataset["train"]),
    "validation": ray.data.from_huggingface(hf_dataset["validation"]),
    "test": ray.data.from_huggingface(hf_dataset["test"]),
}

We first split the original paragraphs into individual lines, dropping blank lines and speaker labels such as "BAPTISTA:", then tokenize them. Here are some samples:

ds = ray_datasets["train"]
splitter.fit_transform(ds).take(10)
[{'text': 'Before we proceed any further, hear me speak.'},
 {'text': 'Speak, speak.'},
 {'text': 'You are all resolved rather to die than to famish?'},
 {'text': 'Resolved. resolved.'},
 {'text': 'First, you know Caius Marcius is chief enemy to the people.'},
 {'text': "We know't, we know't."},
 {'text': "Let us kill him, and we'll have corn at our own price."},
 {'text': "Is't a verdict?"},
 {'text': "No more talking on't; let it be done: away, away!"},
 {'text': 'One word, good citizens.'}]
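
To see the filtering rule in isolation, here is a minimal pure-Python sketch of the same logic applied to two rows built from the dataset sample shown earlier. The function name split_lines is hypothetical; it mirrors split_text but works on a plain list of strings, so it runs without pandas or Ray.

```python
def split_lines(texts):
    """Join the rows, split on newlines, drop blanks and speaker labels."""
    flat_text = "".join(texts)
    return [
        line.strip()
        for line in flat_text.split("\n")
        if line.strip() and not line.strip().endswith(":")
    ]


sample = [
    "BAPTISTA:\nI know him well: you are welcome for his sake.\n\n",
    "GREMIO:\nSaving your tale, Petruchio, I pray,\n",
]
lines = split_lines(sample)
print(lines)
```

Only lines that end with a colon are treated as speaker labels, so a line such as "I know him well: you are welcome for his sake." is kept even though it contains a colon.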

Define your lightning model#

In this example, we use the dolly-v2-7b model for fine-tuning. It is an instruction-following large language model trained on the Databricks machine learning platform and licensed for commercial use. We load the model weights from the Hugging Face Model Hub and encapsulate the model in a pl.LightningModule.

Note

Make sure you pass the FSDP wrapped model parameters self.trainer.model.parameters() into the optimizer, instead of self.model.parameters().

import torch
import pytorch_lightning as pl

class DollyV2Model(pl.LightningModule):
    def __init__(self, lr=2e-5, eps=1e-8):
        super().__init__()
        self.lr = lr
        self.eps = eps
        self.model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
        self.predictions = []
        self.references = []

    def forward(self, batch):
        outputs = self.model(
            batch["input_ids"], 
            attention_mask=batch["attention_mask"], 
            labels=batch["labels"]
        )
        return outputs.loss

    def training_step(self, batch, batch_idx):
        loss = self.forward(batch)
        self.log("train_loss", loss, prog_bar=True, on_step=True)
        return loss

    def configure_optimizers(self):
        if self.global_rank == 0:
            print(self.trainer.model)
        return torch.optim.AdamW(self.trainer.model.parameters(), lr=self.lr, eps=self.eps)

Configure your FSDP strategy#

As dolly-v2-7b is a relatively large model, it cannot fit into a single GPU of the kind used here. In this example, we use the FSDP strategy to shard model parameters across multiple workers. This allows us to avoid GPU out-of-memory issues and support a larger global batch size.
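
A rough back-of-the-envelope estimate makes the memory argument concrete. The sketch below assumes about 6.9 billion parameters and full-precision training state for AdamW (weights, gradients, and two optimizer moments, 4 bytes each). It ignores activations, buffers, and allocator overhead, so real usage is higher; the numbers are an illustration, not a measurement.

```python
GIB = 1024 ** 3

num_params = 6.9e9            # assumption: approximate size of dolly-v2-7b
bytes_per_param = 4 + 4 + 8   # fp32 weights + fp32 gradients + two fp32 AdamW moments
num_gpus = 16                 # workers in this example
gpu_memory_gib = 16           # one Tesla T4

total_gib = num_params * bytes_per_param / GIB
per_gpu_gib = total_gib / num_gpus

print(f"training state, unsharded: {total_gib:.1f} GiB")
print(f"training state, sharded over {num_gpus} GPUs: {per_gpu_gib:.1f} GiB per GPU")
print(f"fits on one T4 unsharded: {total_gib < gpu_memory_gib}")
```

Under these assumptions the unsharded state is roughly 103 GiB, far beyond a single 16 GiB T4, while the fully sharded share is about 6.4 GiB per GPU.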

Image source: Fully Sharded Data Parallel: faster AI training with fewer GPUs

Note

FSDP is a type of data parallelism that shards model parameters, optimizer states, and gradients across DDP ranks. It was inspired by Xu et al. and by ZeRO Stage 3 from DeepSpeed. You may refer to these blogs for more information:

To start training with Lightning’s FSDPStrategy, you only need to provide the initialization arguments in LightningConfigBuilder.strategy(). Behind the scenes, LightningTrainer handles the cluster environment settings and job launching.

import functools
from ray.train.lightning import LightningTrainer, LightningConfigBuilder
from ray.train import RunConfig, ScalingConfig, CheckpointConfig
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.fsdp import ShardingStrategy, BackwardPrefetch
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer

# Define the model sharding policy:
# Wrap every GPTNeoXLayer as its own FSDP instance
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls = {GPTNeoXLayer}
)

# Aggregate all arguments for LightningTrainer
lightning_config = (
    LightningConfigBuilder()
    .module(cls=DollyV2Model, lr=2e-5, eps=1e-8)
    .trainer(
        max_epochs=1, 
        accelerator="gpu", 
        precision="16-mixed",
    )
    .strategy(
        name="fsdp",
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        forward_prefetch=True,
        auto_wrap_policy=auto_wrap_policy,
        limit_all_gathers=True,
        activation_checkpointing=[GPTNeoXLayer],
    )
    .checkpointing(save_top_k=0, save_weights_only=True, save_last=True)
)

Tip

Some tips for FSDP configuration:

  • sharding_strategy:

    • ShardingStrategy.NO_SHARD: Parameters, gradients, and optimizer states are not sharded. Similar to DDP.

    • ShardingStrategy.SHARD_GRAD_OP: Gradients and optimizer states are sharded during computation, and additionally, parameters are sharded outside computation. Similar to ZeRO stage-2.

    • ShardingStrategy.FULL_SHARD: Parameters, gradients, and optimizer states are sharded. It has the lowest GPU memory usage among the 3 options. Similar to ZeRO stage-3.

  • auto_wrap_policy:

    • Model layers are often wrapped with FSDP in a layered fashion. This means that only the parameters of the layers in a single FSDP instance need to be gathered onto one device at a time during the forward or backward pass.

    • Use transformer_auto_wrap_policy to automatically wrap each Transformer Block into a single FSDP instance.

  • backward_prefetch and forward_prefetch:

    • Overlap the upcoming all-gather while executing the current forward/backward pass. It can improve throughput but may slightly increase peak memory usage.
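
To get a feel for why per-layer wrapping helps, the sketch below counts the parameters of one GPTNeoXLayer and of the whole model. The dimensions (hidden size 4096, 32 layers, vocabulary 50280) are read from the module structure that this example prints in its training output; the helper linear is hypothetical and only counts weights and biases.

```python
hidden = 4096      # hidden size, from the Linear/LayerNorm shapes in the printout
num_layers = 32    # "(0-31): 32 x FullyShardedDataParallel"
vocab = 50280      # Embedding(50280, 4096)


def linear(in_features, out_features, bias=True):
    """Parameter count of a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


layer_params = (
    2 * 2 * hidden                   # two LayerNorms, weight and bias each
    + linear(hidden, 3 * hidden)     # attention.query_key_value
    + linear(hidden, hidden)         # attention.dense
    + linear(hidden, 4 * hidden)     # mlp.dense_h_to_4h
    + linear(4 * hidden, hidden)     # mlp.dense_4h_to_h
)
total_params = (
    num_layers * layer_params
    + vocab * hidden                      # embed_in
    + 2 * hidden                          # final_layer_norm
    + linear(hidden, vocab, bias=False)   # embed_out
)
fraction = layer_params / total_params

print(f"parameters per GPTNeoXLayer: {layer_params:,}")
print(f"parameters in the model:     {total_params:,}")
print(f"one layer is {fraction:.1%} of the model")
```

With each transformer block in its own FSDP instance, an all-gather has to materialize only about 3% of the model's parameters at a time, rather than the full model.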

Fine-tune with LightningTrainer#

num_workers = 16
batch_size_per_worker = 10
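
With data-parallel training, each worker processes its own batch in every step, so the effective global batch size is the product of the two values above. A small sketch of that arithmetic follows; the helper steps_per_epoch and the sample count passed to it are hypothetical, and it assumes samples are split evenly across workers.

```python
import math

num_workers = 16
batch_size_per_worker = 10

# Every optimizer step consumes one batch per worker.
global_batch_size = num_workers * batch_size_per_worker


def steps_per_epoch(num_samples):
    """Steps needed for one pass, assuming an even split across workers."""
    per_worker = math.ceil(num_samples / num_workers)
    return math.ceil(per_worker / batch_size_per_worker)


print(f"global batch size: {global_batch_size}")
# 21_400 is a made-up sample count used only for illustration.
print(f"steps per epoch for 21,400 samples: {steps_per_epoch(21_400)}")
```

Changing num_workers therefore changes the global batch size unless batch_size_per_worker is adjusted to compensate.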

Note

Since this example runs with multiple nodes, we need to persist checkpoints and other outputs to some external storage for access after training has completed. You should set up cloud storage or NFS, then replace storage_path with your own cloud bucket URI or NFS path.

See the storage guide for more details.

storage_path="s3://your-bucket-here"  # TODO: Set up cloud storage
# storage_path="/mnt/path/to/nfs"     # TODO: Alternatively, set up NFS
from ray.tune.syncer import SyncConfig
# Persist AIR checkpoints and results to the external storage configured above
run_config = RunConfig(
    storage_path=storage_path,
    name="finetune_dolly-v2-7b",
    checkpoint_config=CheckpointConfig(),
    sync_config=SyncConfig(sync_artifacts=False),
)

# Scale the FSDP training workload across 16 GPUs
# You can change this config based on your compute resources.
scaling_config = ScalingConfig(
    num_workers=num_workers, use_gpu=True, resources_per_worker={"CPU": 12, "GPU": 1}
)

trainer = LightningTrainer(
    lightning_config=lightning_config.build(),
    run_config=run_config,
    scaling_config=scaling_config,
    datasets={"train": ray_datasets["train"]},
    datasets_iter_config={"batch_size": batch_size_per_worker},
    preprocessor=preprocessor,
)
result = trainer.fit()

result

Tune Status

Current time: 2023-05-05 01:03:12
Running for: 00:45:50.28
Memory: 35.4/124.4 GiB

System Info

Using FIFO scheduling algorithm.
Logical resource usage: 0/272 CPUs, 0/16 GPUs (0.0/16.0 accelerator_type:T4)

Trial Status

Trial name                    status      loc                 iter  total time (s)  train_loss  epoch  step
LightningTrainer_e0990_00000  TERMINATED  10.0.102.147:41219  1     2699.78         0.166992    0      135
2023-05-05 00:17:21,842	WARNING trial_runner.py:1607 -- The maximum number of pending trials has been automatically set to the number of available cluster CPUs, which is high (299 CPUs/pending trials). If you're running an experiment with a large number of trials, this could lead to scheduling overhead. In this case, consider setting the `TUNE_MAX_PENDING_TRIALS_PG` environment variable to the desired maximum number of concurrent trials.
(LightningTrainer pid=41219) 2023-05-05 00:17:28,673	INFO backend_executor.py:128 -- Starting distributed worker processes: ['41376 (10.0.102.147)', '8301 (10.0.67.96)', '8263 (10.0.103.36)', '27794 (10.0.105.149)', '8088 (10.0.110.210)', '8238 (10.0.106.19)', '8225 (10.0.81.63)', '8200 (10.0.106.22)', '8231 (10.0.90.160)', '8345 (10.0.98.168)', '28207 (10.0.76.146)', '8213 (10.0.115.72)', '8272 (10.0.92.209)', '8247 (10.0.74.31)', '27629 (10.0.68.102)', '8224 (10.0.88.86)']
(RayTrainWorker pid=41376) 2023-05-05 00:17:30,953	INFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=16]

(pid=41219) Running: 0.0/272.0 CPU, 0.0/16.0 GPU, 0.0 MiB/73.21 GiB object_store_memory:   0%|          | 0/1 [00:00<?, ?it/s]
(pid=41219) - RandomizeBlockOrder: 0 active, 0 queued, 0.0 MiB objects, 0 output:   0%|          | 0/1 [00:00<?, ?it/s]
(LightningTrainer pid=41219)                                                                                                   2023-05-05 00:17:31,564	INFO streaming_executor.py:87 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[BatchMapper->BatchMapper] -> AllToAllOperator[RandomizeBlockOrder]

(pid=41219) Running: 0.0/272.0 CPU, 0.0/16.0 GPU, 0.0 MiB/73.21 GiB object_store_memory:   0%|          | 0/1 [00:00<?, ?it/s]
(LightningTrainer pid=41219)                                                                                                   2023-05-05 00:17:31,564	INFO streaming_executor.py:88 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)

(pid=41219) Running: 0.0/272.0 CPU, 0.0/16.0 GPU, 0.0 MiB/73.21 GiB object_store_memory:   0%|          | 0/1 [00:00<?, ?it/s]
(LightningTrainer pid=41219)                                                                                                   2023-05-05 00:17:31,565	INFO streaming_executor.py:90 -- Tip: To enable per-operator progress reporting, set RAY_DATA_VERBOSE_PROGRESS=1.

(pid=41219) Running: 1.0/272.0 CPU, 0.0/16.0 GPU, 0.96 MiB/73.21 GiB object_store_memory:   0%|          | 0/1 [00:00<?, ?it/s]
Downloading (…)okenizer_config.json: 100%|██████████| 450/450 [00:00<00:00, 68.5kB/s]                                           

(pid=41219) Running: 1.0/272.0 CPU, 0.0/16.0 GPU, 0.96 MiB/73.21 GiB object_store_memory:   0%|          | 0/1 [00:02<?, ?it/s]
Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]                                                

(pid=41219) Running: 1.0/272.0 CPU, 0.0/16.0 GPU, 0.96 MiB/73.21 GiB object_store_memory:   0%|          | 0/1 [00:02<?, ?it/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 2.11M/2.11M [00:00<00:00, 28.0MB/s]                                       

(pid=41219) Running: 1.0/272.0 CPU, 0.0/16.0 GPU, 0.96 MiB/73.21 GiB object_store_memory:   0%|          | 0/1 [00:02<?, ?it/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 228/228 [00:00<00:00, 150kB/s]                                            

(pid=41219) Running: 0.0/272.0 CPU, 0.0/16.0 GPU, 0.0 MiB/73.21 GiB object_store_memory:   0%|          | 0/1 [00:07<?, ?it/s]   
(pid=41219) - RandomizeBlockOrder: 0 active, 0 queued, 0.0 MiB objects, 1 output:   0%|          | 0/1 [00:07<?, ?it/s]
(pid=41219) Running: 0.0/272.0 CPU, 0.0/16.0 GPU, 126.69 MiB/73.21 GiB object_store_memory:   0%|          | 0/1 [00:07<?, ?it/s]
                                                                                                                                         
Downloading (…)lve/main/config.json: 100%|██████████| 819/819 [00:00<00:00, 123kB/s]                                           (RayTrainWorker pid=8247, ip=10.0.74.31) 
Downloading pytorch_model.bin:   0%|          | 0.00/13.8G [00:00<?, ?B/s]
Downloading pytorch_model.bin:   0%|          | 21.0M/13.8G [00:00<01:28, 156MB/s]
Downloading (…)lve/main/config.json: 100%|██████████| 819/819 [00:00<00:00, 125kB/s] [repeated 15x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
Downloading pytorch_model.bin: 100%|██████████| 13.8G/13.8G [00:57<00:00, 239MB/s]
(RayTrainWorker pid=8231, ip=10.0.90.160) Using 16bit Automatic Mixed Precision (AMP)
(RayTrainWorker pid=8088, ip=10.0.110.210) Using 16bit Automatic Mixed Precision (AMP)
(RayTrainWorker pid=8231, ip=10.0.90.160) Missing logger folder: /home/ray/ray_results/finetune_dolly-v2-7b/LightningTrainer_e0990_00000_0_2023-05-05_00-17-21/rank_8/lightning_logs
(RayTrainWorker pid=8345, ip=10.0.98.168) Using 16bit Automatic Mixed Precision (AMP) [repeated 4x across cluster]
(RayTrainWorker pid=8345, ip=10.0.98.168) Missing logger folder: /home/ray/ray_results/finetune_dolly-v2-7b/LightningTrainer_e0990_00000_0_2023-05-05_00-17-21/rank_9/lightning_logs [repeated 4x across cluster]
(RayTrainWorker pid=41376) GPU available: True (cuda), used: True
(RayTrainWorker pid=41376) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=41376) IPU available: False, using: 0 IPUs
(RayTrainWorker pid=41376) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=8238, ip=10.0.106.19) Using 16bit Automatic Mixed Precision (AMP) [repeated 8x across cluster]
(RayTrainWorker pid=8213, ip=10.0.115.72) Missing logger folder: /home/ray/ray_results/finetune_dolly-v2-7b/LightningTrainer_e0990_00000_0_2023-05-05_00-17-21/rank_11/lightning_logs [repeated 7x across cluster]
(RayTrainWorker pid=8238, ip=10.0.106.19) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
(RayTrainWorker pid=8088, ip=10.0.110.210) Using 16bit Automatic Mixed Precision (AMP) [repeated 3x across cluster]
(RayTrainWorker pid=8088, ip=10.0.110.210) Missing logger folder: /home/ray/ray_results/finetune_dolly-v2-7b/LightningTrainer_e0990_00000_0_2023-05-05_00-17-21/rank_4/lightning_logs [repeated 4x across cluster]
(RayTrainWorker pid=41376) 
(RayTrainWorker pid=41376)   | Name  | Type               | Params
(RayTrainWorker pid=41376) ---------------------------------------------
(RayTrainWorker pid=41376) 0 | model | GPTNeoXForCausalLM | 402 M 
(RayTrainWorker pid=41376) ---------------------------------------------
(RayTrainWorker pid=41376) 402 M     Trainable params
(RayTrainWorker pid=41376) 0         Non-trainable params
(RayTrainWorker pid=41376) 402 M     Total params
(RayTrainWorker pid=41376) 1,611.039 Total estimated model params size (MB)
(RayTrainWorker pid=8088, ip=10.0.110.210) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] [repeated 15x across cluster]
(RayTrainWorker pid=41376) FullyShardedDataParallel(
(RayTrainWorker pid=41376)   (_fsdp_wrapped_module): _LightningModuleWrapperBase(
(RayTrainWorker pid=41376)     (_forward_module): DollyV2Model(
(RayTrainWorker pid=41376)       (model): GPTNeoXForCausalLM(
(RayTrainWorker pid=41376)         (gpt_neox): GPTNeoXModel(
(RayTrainWorker pid=41376)           (embed_in): Embedding(50280, 4096)
(RayTrainWorker pid=41376)           (layers): ModuleList(
(RayTrainWorker pid=41376)             (0-31): 32 x FullyShardedDataParallel(
(RayTrainWorker pid=41376)               (_fsdp_wrapped_module): CheckpointWrapper(
(RayTrainWorker pid=41376)                 (_checkpoint_wrapped_module): GPTNeoXLayer(
(RayTrainWorker pid=41376)                   (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
(RayTrainWorker pid=41376)                   (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
(RayTrainWorker pid=41376)                   (attention): GPTNeoXAttention(
(RayTrainWorker pid=41376)                     (rotary_emb): RotaryEmbedding()
(RayTrainWorker pid=41376)                     (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
(RayTrainWorker pid=41376)                     (dense): Linear(in_features=4096, out_features=4096, bias=True)
(RayTrainWorker pid=41376)                   )
(RayTrainWorker pid=41376)                   (mlp): GPTNeoXMLP(
(RayTrainWorker pid=41376)                     (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
(RayTrainWorker pid=41376)                     (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
(RayTrainWorker pid=41376)                     (act): GELUActivation()
(RayTrainWorker pid=41376)                   )
(RayTrainWorker pid=41376)                 )
(RayTrainWorker pid=41376)               )
(RayTrainWorker pid=41376)             )
(RayTrainWorker pid=41376)           )
(RayTrainWorker pid=41376)           (final_layer_norm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
(RayTrainWorker pid=41376)         )
(RayTrainWorker pid=41376)         (embed_out): Linear(in_features=4096, out_features=50280, bias=False)
(RayTrainWorker pid=41376)       )
(RayTrainWorker pid=41376)     )
(RayTrainWorker pid=41376)   )
(RayTrainWorker pid=41376) )
(RayTrainWorker pid=41376) /home/ray/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
(RayTrainWorker pid=41376)   rank_zero_warn(
Epoch 0:   0%|          | 0/134 [00:00<?, ?it/s]
Epoch 0:   1%|          | 1/134 [00:19<43:46, 19.75s/it, v_num=0, train_loss=12.90]
Epoch 0:   1%|▏         | 2/134 [00:37<40:43, 18.51s/it, v_num=0, train_loss=12.50]
Epoch 0:   2%|▏         | 3/134 [00:54<39:20, 18.02s/it, v_num=0, train_loss=12.50]
Epoch 0:   3%|▎         | 4/134 [01:11<38:48, 17.91s/it, v_num=0, train_loss=12.50]
Epoch 0:   4%|▎         | 5/134 [01:28<38:14, 17.78s/it, v_num=0, train_loss=12.50]
Epoch 0:   4%|▍         | 6/134 [01:46<37:45, 17.70s/it, v_num=0, train_loss=12.50]
Epoch 0:   5%|▌         | 7/134 [02:03<37:17, 17.62s/it, v_num=0, train_loss=12.50]
Epoch 0:   6%|▌         | 8/134 [02:20<36:52, 17.56s/it, v_num=0, train_loss=12.50]
Epoch 0:   7%|▋         | 9/134 [02:37<36:30, 17.52s/it, v_num=0, train_loss=12.50]
Epoch 0:   7%|▋         | 9/134 [02:37<36:32, 17.54s/it, v_num=0, train_loss=12.50]
Epoch 0:   7%|▋         | 10/134 [02:55<36:12, 17.52s/it, v_num=0, train_loss=12.50]
Epoch 0:   7%|▋         | 10/134 [02:55<36:14, 17.54s/it, v_num=0, train_loss=0.669]
Epoch 0:   8%|▊         | 11/134 [03:12<35:55, 17.53s/it, v_num=0, train_loss=0.669]
Epoch 0:   8%|▊         | 11/134 [03:12<35:57, 17.54s/it, v_num=0, train_loss=0.663]
Epoch 0:   9%|▉         | 12/134 [03:30<35:38, 17.53s/it, v_num=0, train_loss=0.663]
Epoch 0:   9%|▉         | 12/134 [03:30<35:39, 17.54s/it, v_num=0, train_loss=0.604]
Epoch 0:  10%|▉         | 13/134 [03:47<35:20, 17.53s/it, v_num=0, train_loss=0.604]
Epoch 0:  10%|▉         | 13/134 [03:48<35:22, 17.54s/it, v_num=0, train_loss=0.601]
Epoch 0:  10%|█         | 14/134 [04:05<35:01, 17.51s/it, v_num=0, train_loss=0.601]
Epoch 0:  10%|█         | 14/134 [04:05<35:02, 17.52s/it, v_num=0, train_loss=0.586]
Epoch 0:  11%|█         | 15/134 [04:22<34:45, 17.53s/it, v_num=0, train_loss=0.586]
Epoch 0:  11%|█         | 15/134 [04:23<34:46, 17.54s/it, v_num=0, train_loss=0.551]
Epoch 0:  12%|█▏        | 16/134 [04:40<34:28, 17.53s/it, v_num=0, train_loss=0.551]
Epoch 0:  12%|█▏        | 16/134 [04:40<34:29, 17.54s/it, v_num=0, train_loss=0.516]
Epoch 0:  13%|█▎        | 17/134 [04:57<34:10, 17.52s/it, v_num=0, train_loss=0.516]
Epoch 0:  13%|█▎        | 17/134 [04:58<34:11, 17.53s/it, v_num=0, train_loss=0.521]
Epoch 0:  13%|█▎        | 18/134 [05:15<33:52, 17.52s/it, v_num=0, train_loss=0.521]
Epoch 0:  13%|█▎        | 18/134 [05:15<33:53, 17.53s/it, v_num=0, train_loss=0.511]
Epoch 0:  14%|█▍        | 19/134 [05:32<33:34, 17.52s/it, v_num=0, train_loss=0.511]
Epoch 0:  14%|█▍        | 19/134 [05:33<33:35, 17.53s/it, v_num=0, train_loss=0.470]
Epoch 0:  15%|█▍        | 20/134 [05:50<33:17, 17.52s/it, v_num=0, train_loss=0.470]
Epoch 0:  15%|█▍        | 20/134 [05:50<33:17, 17.53s/it, v_num=0, train_loss=0.443]
Epoch 0:  16%|█▌        | 21/134 [06:07<32:59, 17.52s/it, v_num=0, train_loss=0.443]
Epoch 0:  16%|█▌        | 21/134 [06:08<33:00, 17.53s/it, v_num=0, train_loss=0.466]
Epoch 0:  16%|█▋        | 22/134 [06:25<32:42, 17.52s/it, v_num=0, train_loss=0.466]
Epoch 0:  16%|█▋        | 22/134 [06:25<32:42, 17.53s/it, v_num=0, train_loss=0.434]
Epoch 0:  17%|█▋        | 23/134 [06:43<32:25, 17.53s/it, v_num=0, train_loss=0.434]
Epoch 0:  17%|█▋        | 23/134 [06:43<32:26, 17.53s/it, v_num=0, train_loss=0.403]
Epoch 0:  18%|█▊        | 24/134 [07:00<32:08, 17.53s/it, v_num=0, train_loss=0.403]
Epoch 0:  18%|█▊        | 24/134 [07:00<32:09, 17.54s/it, v_num=0, train_loss=0.370]
Epoch 0:  19%|█▊        | 25/134 [07:18<31:51, 17.53s/it, v_num=0, train_loss=0.370]
Epoch 0:  19%|█▊        | 25/134 [07:18<31:51, 17.54s/it, v_num=0, train_loss=0.361]
Epoch 0:  19%|█▉        | 26/134 [07:35<31:34, 17.54s/it, v_num=0, train_loss=0.361]
Epoch 0:  19%|█▉        | 26/134 [07:36<31:34, 17.54s/it, v_num=0, train_loss=0.383]
Epoch 0:  20%|██        | 27/134 [07:54<31:18, 17.56s/it, v_num=0, train_loss=0.383]
Epoch 0:  20%|██        | 27/134 [07:54<31:19, 17.56s/it, v_num=0, train_loss=0.360]
Epoch 0:  21%|██        | 28/134 [08:11<31:01, 17.56s/it, v_num=0, train_loss=0.360]
Epoch 0:  21%|██        | 28/134 [08:11<31:02, 17.57s/it, v_num=0, train_loss=0.382]
Epoch 0:  22%|██▏       | 29/134 [08:29<30:44, 17.57s/it, v_num=0, train_loss=0.382]
Epoch 0:  22%|██▏       | 29/134 [08:29<30:45, 17.57s/it, v_num=0, train_loss=0.328]
Epoch 0:  22%|██▏       | 30/134 [08:47<30:28, 17.58s/it, v_num=0, train_loss=0.328]
Epoch 0:  22%|██▏       | 30/134 [08:47<30:28, 17.58s/it, v_num=0, train_loss=0.342]
Epoch 0:  23%|██▎       | 31/134 [09:04<30:10, 17.57s/it, v_num=0, train_loss=0.342]
Epoch 0:  23%|██▎       | 31/134 [09:04<30:10, 17.58s/it, v_num=0, train_loss=0.303]
Epoch 0:  24%|██▍       | 32/134 [09:22<29:51, 17.57s/it, v_num=0, train_loss=0.303]
Epoch 0:  24%|██▍       | 32/134 [09:22<29:52, 17.57s/it, v_num=0, train_loss=0.326]
Epoch 0:  25%|██▍       | 33/134 [09:39<29:33, 17.56s/it, v_num=0, train_loss=0.326]
Epoch 0:  25%|██▍       | 33/134 [09:39<29:34, 17.57s/it, v_num=0, train_loss=0.285]
Epoch 0:  25%|██▌       | 34/134 [09:56<29:14, 17.55s/it, v_num=0, train_loss=0.285]
Epoch 0:  25%|██▌       | 34/134 [09:56<29:15, 17.55s/it, v_num=0, train_loss=0.321]
Epoch 0:  26%|██▌       | 35/134 [10:14<28:57, 17.55s/it, v_num=0, train_loss=0.321]
Epoch 0:  26%|██▌       | 35/134 [10:14<28:57, 17.55s/it, v_num=0, train_loss=0.341]
Epoch 0:  27%|██▋       | 36/134 [10:31<28:39, 17.55s/it, v_num=0, train_loss=0.341]
Epoch 0:  27%|██▋       | 36/134 [10:31<28:40, 17.55s/it, v_num=0, train_loss=0.296]
Epoch 0:  28%|██▊       | 37/134 [10:49<28:21, 17.54s/it, v_num=0, train_loss=0.296]
Epoch 0:  28%|██▊       | 37/134 [10:49<28:22, 17.55s/it, v_num=0, train_loss=0.288]
Epoch 0:  28%|██▊       | 38/134 [11:07<28:05, 17.55s/it, v_num=0, train_loss=0.288]
Epoch 0:  28%|██▊       | 38/134 [11:07<28:05, 17.56s/it, v_num=0, train_loss=0.280]
Epoch 0:  29%|██▉       | 39/134 [11:24<27:47, 17.55s/it, v_num=0, train_loss=0.280]
Epoch 0:  29%|██▉       | 39/134 [11:24<27:47, 17.55s/it, v_num=0, train_loss=0.257]
Epoch 0:  30%|██▉       | 40/134 [11:42<27:29, 17.55s/it, v_num=0, train_loss=0.257]
Epoch 0:  30%|██▉       | 40/134 [11:42<27:30, 17.56s/it, v_num=0, train_loss=0.271]
Epoch 0:  31%|███       | 41/134 [11:59<27:11, 17.55s/it, v_num=0, train_loss=0.271]
Epoch 0:  31%|███       | 41/134 [11:59<27:12, 17.55s/it, v_num=0, train_loss=0.243]
Epoch 0:  31%|███▏      | 42/134 [12:16<26:54, 17.54s/it, v_num=0, train_loss=0.243]
Epoch 0:  31%|███▏      | 42/134 [12:17<26:54, 17.55s/it, v_num=0, train_loss=0.267]
Epoch 0:  32%|███▏      | 43/134 [12:34<26:36, 17.54s/it, v_num=0, train_loss=0.267]
Epoch 0:  32%|███▏      | 43/134 [12:34<26:36, 17.54s/it, v_num=0, train_loss=0.249]
Epoch 0:  33%|███▎      | 44/134 [12:51<26:18, 17.54s/it, v_num=0, train_loss=0.249]
Epoch 0:  33%|███▎      | 44/134 [12:51<26:18, 17.54s/it, v_num=0, train_loss=0.262]
Epoch 0:  34%|███▎      | 45/134 [13:09<26:01, 17.54s/it, v_num=0, train_loss=0.262]
Epoch 0:  34%|███▎      | 45/134 [13:09<26:01, 17.54s/it, v_num=0, train_loss=0.185]
Epoch 0:  34%|███▍      | 46/134 [13:26<25:43, 17.54s/it, v_num=0, train_loss=0.185]
Epoch 0:  34%|███▍      | 46/134 [13:26<25:43, 17.54s/it, v_num=0, train_loss=0.261]
Epoch 0:  35%|███▌      | 47/134 [13:43<25:25, 17.53s/it, v_num=0, train_loss=0.261]
Epoch 0:  35%|███▌      | 47/134 [13:44<25:25, 17.53s/it, v_num=0, train_loss=0.233]
Epoch 0:  36%|███▌      | 48/134 [14:01<25:07, 17.53s/it, v_num=0, train_loss=0.233]
Epoch 0:  36%|███▌      | 48/134 [14:01<25:07, 17.53s/it, v_num=0, train_loss=0.255]
Epoch 0:  37%|███▋      | 49/134 [14:18<24:50, 17.53s/it, v_num=0, train_loss=0.255]
Epoch 0:  37%|███▋      | 49/134 [14:19<24:50, 17.53s/it, v_num=0, train_loss=0.282]
Epoch 0:  37%|███▋      | 50/134 [14:36<24:32, 17.53s/it, v_num=0, train_loss=0.282]
Epoch 0:  37%|███▋      | 50/134 [14:36<24:32, 17.53s/it, v_num=0, train_loss=0.202]
Epoch 0:  38%|███▊      | 51/134 [14:54<24:15, 17.53s/it, v_num=0, train_loss=0.202]
Epoch 0:  38%|███▊      | 51/134 [14:54<24:15, 17.54s/it, v_num=0, train_loss=0.239]
Epoch 0:  39%|███▉      | 52/134 [15:11<23:57, 17.53s/it, v_num=0, train_loss=0.239]
Epoch 0:  39%|███▉      | 52/134 [15:11<23:57, 17.53s/it, v_num=0, train_loss=0.227]
Epoch 0:  40%|███▉      | 53/134 [15:28<23:39, 17.53s/it, v_num=0, train_loss=0.227]
Epoch 0:  40%|███▉      | 53/134 [15:29<23:39, 17.53s/it, v_num=0, train_loss=0.240]
Epoch 0:  40%|████      | 54/134 [15:46<23:21, 17.52s/it, v_num=0, train_loss=0.240]
Epoch 0:  40%|████      | 54/134 [15:46<23:21, 17.52s/it, v_num=0, train_loss=0.205]
Epoch 0:  41%|████      | 55/134 [16:03<23:04, 17.52s/it, v_num=0, train_loss=0.205]
Epoch 0:  41%|████      | 55/134 [16:03<23:04, 17.53s/it, v_num=0, train_loss=0.218]
Epoch 0:  42%|████▏     | 56/134 [16:21<22:47, 17.53s/it, v_num=0, train_loss=0.218]
Epoch 0:  42%|████▏     | 56/134 [16:21<22:47, 17.53s/it, v_num=0, train_loss=0.199]
Epoch 0:  43%|████▎     | 57/134 [16:38<22:29, 17.52s/it, v_num=0, train_loss=0.199]
Epoch 0:  43%|████▎     | 57/134 [16:39<22:29, 17.53s/it, v_num=0, train_loss=0.194]
Epoch 0:  43%|████▎     | 58/134 [16:56<22:11, 17.52s/it, v_num=0, train_loss=0.194]
Epoch 0:  43%|████▎     | 58/134 [16:56<22:11, 17.53s/it, v_num=0, train_loss=0.193]
Epoch 0:  44%|████▍     | 59/134 [17:13<21:54, 17.52s/it, v_num=0, train_loss=0.193]
Epoch 0:  44%|████▍     | 59/134 [17:13<21:54, 17.52s/it, v_num=0, train_loss=0.204]
Epoch 0:  45%|████▍     | 60/134 [17:31<21:36, 17.52s/it, v_num=0, train_loss=0.204]
Epoch 0:  45%|████▍     | 60/134 [17:31<21:36, 17.52s/it, v_num=0, train_loss=0.197]
Epoch 0:  46%|████▌     | 61/134 [17:48<21:18, 17.52s/it, v_num=0, train_loss=0.197]
Epoch 0:  46%|████▌     | 61/134 [17:48<21:18, 17.52s/it, v_num=0, train_loss=0.211]
Epoch 0:  46%|████▋     | 62/134 [18:06<21:01, 17.52s/it, v_num=0, train_loss=0.211]
Epoch 0:  46%|████▋     | 62/134 [18:06<21:01, 17.52s/it, v_num=0, train_loss=0.203]
Epoch 0:  47%|████▋     | 63/134 [18:23<20:43, 17.52s/it, v_num=0, train_loss=0.203]
Epoch 0:  47%|████▋     | 63/134 [18:23<20:43, 17.52s/it, v_num=0, train_loss=0.217]
Epoch 0:  48%|████▊     | 64/134 [18:41<20:26, 17.52s/it, v_num=0, train_loss=0.217]
Epoch 0:  48%|████▊     | 64/134 [18:41<20:26, 17.52s/it, v_num=0, train_loss=0.214]
Epoch 0:  49%|████▊     | 65/134 [18:58<20:08, 17.51s/it, v_num=0, train_loss=0.214]
Epoch 0:  49%|████▊     | 65/134 [18:58<20:08, 17.52s/it, v_num=0, train_loss=0.215]
Epoch 0:  49%|████▉     | 66/134 [19:15<19:50, 17.51s/it, v_num=0, train_loss=0.215]
Epoch 0:  49%|████▉     | 66/134 [19:15<19:50, 17.51s/it, v_num=0, train_loss=0.216]
Epoch 0:  50%|█████     | 67/134 [19:33<19:33, 17.51s/it, v_num=0, train_loss=0.216]
Epoch 0:  50%|█████     | 67/134 [19:33<19:33, 17.52s/it, v_num=0, train_loss=0.207]
Epoch 0:  51%|█████     | 68/134 [19:50<19:15, 17.51s/it, v_num=0, train_loss=0.207]
Epoch 0:  51%|█████     | 68/134 [19:50<19:15, 17.51s/it, v_num=0, train_loss=0.242]
Epoch 0:  51%|█████▏    | 69/134 [20:08<18:58, 17.51s/it, v_num=0, train_loss=0.242]
Epoch 0:  51%|█████▏    | 69/134 [20:08<18:58, 17.51s/it, v_num=0, train_loss=0.196]
Epoch 0:  52%|█████▏    | 70/134 [20:25<18:40, 17.51s/it, v_num=0, train_loss=0.196]
Epoch 0:  52%|█████▏    | 70/134 [20:25<18:40, 17.51s/it, v_num=0, train_loss=0.224]
Epoch 0:  53%|█████▎    | 71/134 [20:43<18:23, 17.51s/it, v_num=0, train_loss=0.224]
Epoch 0:  53%|█████▎    | 71/134 [20:43<18:23, 17.51s/it, v_num=0, train_loss=0.212]
Epoch 0:  54%|█████▎    | 72/134 [21:00<18:05, 17.51s/it, v_num=0, train_loss=0.212]
Epoch 0:  54%|█████▎    | 72/134 [21:00<18:05, 17.51s/it, v_num=0, train_loss=0.189]
Epoch 0:  54%|█████▍    | 73/134 [21:18<17:48, 17.51s/it, v_num=0, train_loss=0.189]
Epoch 0:  54%|█████▍    | 73/134 [21:18<17:48, 17.51s/it, v_num=0, train_loss=0.240]
Epoch 0:  55%|█████▌    | 74/134 [21:35<17:30, 17.51s/it, v_num=0, train_loss=0.240]
Epoch 0:  55%|█████▌    | 74/134 [21:35<17:30, 17.51s/it, v_num=0, train_loss=0.233]
Epoch 0:  56%|█████▌    | 75/134 [21:53<17:12, 17.51s/it, v_num=0, train_loss=0.233]
Epoch 0:  56%|█████▌    | 75/134 [21:53<17:13, 17.51s/it, v_num=0, train_loss=0.216]
Epoch 0:  57%|█████▋    | 76/134 [22:10<16:55, 17.51s/it, v_num=0, train_loss=0.216]
Epoch 0:  57%|█████▋    | 76/134 [22:10<16:55, 17.51s/it, v_num=0, train_loss=0.177]
Epoch 0:  57%|█████▋    | 77/134 [22:27<16:37, 17.51s/it, v_num=0, train_loss=0.177]
Epoch 0:  57%|█████▋    | 77/134 [22:28<16:37, 17.51s/it, v_num=0, train_loss=0.187]
Epoch 0:  58%|█████▊    | 78/134 [22:45<16:20, 17.50s/it, v_num=0, train_loss=0.187]
Epoch 0:  58%|█████▊    | 78/134 [22:45<16:20, 17.51s/it, v_num=0, train_loss=0.178]
Epoch 0:  59%|█████▉    | 79/134 [23:02<16:02, 17.51s/it, v_num=0, train_loss=0.178]
Epoch 0:  59%|█████▉    | 79/134 [23:03<16:02, 17.51s/it, v_num=0, train_loss=0.216]
Epoch 0:  60%|█████▉    | 80/134 [23:21<15:45, 17.51s/it, v_num=0, train_loss=0.216]
Epoch 0:  60%|█████▉    | 80/134 [23:21<15:45, 17.51s/it, v_num=0, train_loss=0.244]
Epoch 0:  60%|██████    | 81/134 [23:38<15:28, 17.51s/it, v_num=0, train_loss=0.244]
Epoch 0:  60%|██████    | 81/134 [23:38<15:28, 17.51s/it, v_num=0, train_loss=0.225]
Epoch 0:  61%|██████    | 82/134 [23:56<15:10, 17.51s/it, v_num=0, train_loss=0.225]
Epoch 0:  61%|██████    | 82/134 [23:56<15:10, 17.52s/it, v_num=0, train_loss=0.150]
Epoch 0:  62%|██████▏   | 83/134 [24:13<14:53, 17.51s/it, v_num=0, train_loss=0.150]
Epoch 0:  62%|██████▏   | 83/134 [24:13<14:53, 17.51s/it, v_num=0, train_loss=0.211]
Epoch 0:  63%|██████▎   | 84/134 [24:31<14:35, 17.51s/it, v_num=0, train_loss=0.211]
Epoch 0:  63%|██████▎   | 84/134 [24:31<14:35, 17.51s/it, v_num=0, train_loss=0.216]
Epoch 0:  63%|██████▎   | 85/134 [24:48<14:17, 17.51s/it, v_num=0, train_loss=0.216]
Epoch 0:  63%|██████▎   | 85/134 [24:48<14:18, 17.51s/it, v_num=0, train_loss=0.217]
Epoch 0:  64%|██████▍   | 86/134 [25:06<14:00, 17.51s/it, v_num=0, train_loss=0.217]
Epoch 0:  64%|██████▍   | 86/134 [25:06<14:00, 17.51s/it, v_num=0, train_loss=0.236]
Epoch 0:  65%|██████▍   | 87/134 [25:23<13:42, 17.51s/it, v_num=0, train_loss=0.236]
Epoch 0:  65%|██████▍   | 87/134 [25:23<13:43, 17.51s/it, v_num=0, train_loss=0.276]
Epoch 0:  66%|██████▌   | 88/134 [25:40<13:25, 17.51s/it, v_num=0, train_loss=0.276]
Epoch 0:  66%|██████▌   | 88/134 [25:40<13:25, 17.51s/it, v_num=0, train_loss=0.262]
Epoch 0:  66%|██████▋   | 89/134 [25:58<13:07, 17.51s/it, v_num=0, train_loss=0.262]
Epoch 0:  66%|██████▋   | 89/134 [25:58<13:07, 17.51s/it, v_num=0, train_loss=0.244]
Epoch 0:  67%|██████▋   | 90/134 [26:15<12:50, 17.51s/it, v_num=0, train_loss=0.244]
Epoch 0:  67%|██████▋   | 90/134 [26:15<12:50, 17.51s/it, v_num=0, train_loss=0.246]
Epoch 0:  68%|██████▊   | 91/134 [26:33<12:32, 17.51s/it, v_num=0, train_loss=0.246]
Epoch 0:  68%|██████▊   | 91/134 [26:33<12:32, 17.51s/it, v_num=0, train_loss=0.261]
Epoch 0:  69%|██████▊   | 92/134 [26:50<12:15, 17.51s/it, v_num=0, train_loss=0.261]
Epoch 0:  69%|██████▊   | 92/134 [26:50<12:15, 17.51s/it, v_num=0, train_loss=0.174]
Epoch 0:  69%|██████▉   | 93/134 [27:08<11:57, 17.51s/it, v_num=0, train_loss=0.174]
Epoch 0:  69%|██████▉   | 93/134 [27:08<11:57, 17.51s/it, v_num=0, train_loss=0.219]
Epoch 0:  70%|███████   | 94/134 [27:25<11:40, 17.50s/it, v_num=0, train_loss=0.219]
Epoch 0:  70%|███████   | 94/134 [27:25<11:40, 17.50s/it, v_num=0, train_loss=0.225]
Epoch 0:  71%|███████   | 95/134 [27:42<11:22, 17.50s/it, v_num=0, train_loss=0.225]
Epoch 0:  71%|███████   | 95/134 [27:42<11:22, 17.50s/it, v_num=0, train_loss=0.208]
Epoch 0:  72%|███████▏  | 96/134 [27:59<11:04, 17.50s/it, v_num=0, train_loss=0.208]
Epoch 0:  72%|███████▏  | 96/134 [27:59<11:04, 17.50s/it, v_num=0, train_loss=0.211]
Epoch 0:  72%|███████▏  | 97/134 [28:16<10:47, 17.49s/it, v_num=0, train_loss=0.211]
Epoch 0:  72%|███████▏  | 97/134 [28:17<10:47, 17.50s/it, v_num=0, train_loss=0.226]
Epoch 0:  73%|███████▎  | 98/134 [28:34<10:29, 17.49s/it, v_num=0, train_loss=0.226]
Epoch 0:  73%|███████▎  | 98/134 [28:34<10:29, 17.50s/it, v_num=0, train_loss=0.148]
Epoch 0:  74%|███████▍  | 99/134 [28:51<10:12, 17.49s/it, v_num=0, train_loss=0.148]
Epoch 0:  74%|███████▍  | 99/134 [28:51<10:12, 17.49s/it, v_num=0, train_loss=0.187]
Epoch 0:  75%|███████▍  | 100/134 [29:08<09:54, 17.49s/it, v_num=0, train_loss=0.187]
Epoch 0:  75%|███████▍  | 100/134 [29:09<09:54, 17.49s/it, v_num=0, train_loss=0.189]
Epoch 0:  75%|███████▌  | 101/134 [29:26<09:37, 17.49s/it, v_num=0, train_loss=0.189]
Epoch 0:  75%|███████▌  | 101/134 [29:26<09:37, 17.49s/it, v_num=0, train_loss=0.153]
Epoch 0:  76%|███████▌  | 102/134 [29:43<09:19, 17.48s/it, v_num=0, train_loss=0.153]
Epoch 0:  76%|███████▌  | 102/134 [29:43<09:19, 17.49s/it, v_num=0, train_loss=0.256]
Epoch 0:  77%|███████▋  | 103/134 [30:00<09:01, 17.48s/it, v_num=0, train_loss=0.256]
Epoch 0:  77%|███████▋  | 103/134 [30:00<09:01, 17.48s/it, v_num=0, train_loss=0.243]
Epoch 0:  78%|███████▊  | 104/134 [30:17<08:44, 17.48s/it, v_num=0, train_loss=0.243]
Epoch 0:  78%|███████▊  | 104/134 [30:18<08:44, 17.48s/it, v_num=0, train_loss=0.144]
Epoch 0:  78%|███████▊  | 105/134 [30:35<08:26, 17.48s/it, v_num=0, train_loss=0.144]
Epoch 0:  78%|███████▊  | 105/134 [30:35<08:26, 17.48s/it, v_num=0, train_loss=0.194]
Epoch 0:  79%|███████▉  | 106/134 [30:52<08:09, 17.48s/it, v_num=0, train_loss=0.194]
Epoch 0:  79%|███████▉  | 106/134 [30:52<08:09, 17.48s/it, v_num=0, train_loss=0.164]
Epoch 0:  80%|███████▉  | 107/134 [31:10<07:52, 17.48s/it, v_num=0, train_loss=0.164]
Epoch 0:  80%|███████▉  | 107/134 [31:10<07:52, 17.49s/it, v_num=0, train_loss=0.217]
Epoch 0:  81%|████████  | 108/134 [31:28<07:34, 17.49s/it, v_num=0, train_loss=0.217]
Epoch 0:  81%|████████  | 108/134 [31:28<07:34, 17.49s/it, v_num=0, train_loss=0.180]
Epoch 0:  81%|████████▏ | 109/134 [31:46<07:17, 17.49s/it, v_num=0, train_loss=0.180]
Epoch 0:  81%|████████▏ | 109/134 [31:46<07:17, 17.49s/it, v_num=0, train_loss=0.195]
Epoch 0:  82%|████████▏ | 110/134 [32:03<06:59, 17.49s/it, v_num=0, train_loss=0.195]
Epoch 0:  82%|████████▏ | 110/134 [32:04<06:59, 17.49s/it, v_num=0, train_loss=0.197]
Epoch 0:  83%|████████▎ | 111/134 [32:21<06:42, 17.49s/it, v_num=0, train_loss=0.197]
Epoch 0:  83%|████████▎ | 111/134 [32:21<06:42, 17.49s/it, v_num=0, train_loss=0.251]
Epoch 0:  84%|████████▎ | 112/134 [32:38<06:24, 17.49s/it, v_num=0, train_loss=0.251]
Epoch 0:  84%|████████▎ | 112/134 [32:38<06:24, 17.49s/it, v_num=0, train_loss=0.231]
Epoch 0:  84%|████████▍ | 113/134 [32:56<06:07, 17.49s/it, v_num=0, train_loss=0.231]
Epoch 0:  84%|████████▍ | 113/134 [32:56<06:07, 17.49s/it, v_num=0, train_loss=0.211]
Epoch 0:  85%|████████▌ | 114/134 [33:13<05:49, 17.49s/it, v_num=0, train_loss=0.211]
Epoch 0:  85%|████████▌ | 114/134 [33:13<05:49, 17.49s/it, v_num=0, train_loss=0.173]
Epoch 0:  86%|████████▌ | 115/134 [33:31<05:32, 17.49s/it, v_num=0, train_loss=0.173]
Epoch 0:  86%|████████▌ | 115/134 [33:31<05:32, 17.49s/it, v_num=0, train_loss=0.175]
Epoch 0:  87%|████████▋ | 116/134 [33:48<05:14, 17.49s/it, v_num=0, train_loss=0.175]
Epoch 0:  87%|████████▋ | 116/134 [33:48<05:14, 17.49s/it, v_num=0, train_loss=0.156]
Epoch 0:  87%|████████▋ | 117/134 [34:06<04:57, 17.49s/it, v_num=0, train_loss=0.156]
Epoch 0:  87%|████████▋ | 117/134 [34:06<04:57, 17.49s/it, v_num=0, train_loss=0.149]
Epoch 0:  88%|████████▊ | 118/134 [34:23<04:39, 17.49s/it, v_num=0, train_loss=0.149]
Epoch 0:  88%|████████▊ | 118/134 [34:24<04:39, 17.49s/it, v_num=0, train_loss=0.170]
Epoch 0:  89%|████████▉ | 119/134 [34:41<04:22, 17.49s/it, v_num=0, train_loss=0.170]
Epoch 0:  89%|████████▉ | 119/134 [34:41<04:22, 17.49s/it, v_num=0, train_loss=0.220]
Epoch 0:  90%|████████▉ | 120/134 [34:58<04:04, 17.49s/it, v_num=0, train_loss=0.220]
Epoch 0:  90%|████████▉ | 120/134 [34:58<04:04, 17.49s/it, v_num=0, train_loss=0.246]
Epoch 0:  90%|█████████ | 121/134 [35:15<03:47, 17.49s/it, v_num=0, train_loss=0.246]
Epoch 0:  90%|█████████ | 121/134 [35:16<03:47, 17.49s/it, v_num=0, train_loss=0.238]
Epoch 0:  91%|█████████ | 122/134 [35:33<03:29, 17.49s/it, v_num=0, train_loss=0.238]
Epoch 0:  91%|█████████ | 122/134 [35:33<03:29, 17.49s/it, v_num=0, train_loss=0.230]
Epoch 0:  92%|█████████▏| 123/134 [35:50<03:12, 17.49s/it, v_num=0, train_loss=0.230]
Epoch 0:  92%|█████████▏| 123/134 [35:50<03:12, 17.49s/it, v_num=0, train_loss=0.189]
Epoch 0:  93%|█████████▎| 124/134 [36:08<02:54, 17.49s/it, v_num=0, train_loss=0.189]
Epoch 0:  93%|█████████▎| 124/134 [36:08<02:54, 17.49s/it, v_num=0, train_loss=0.140]
Epoch 0:  93%|█████████▎| 125/134 [36:25<02:37, 17.49s/it, v_num=0, train_loss=0.140]
Epoch 0:  93%|█████████▎| 125/134 [36:26<02:37, 17.49s/it, v_num=0, train_loss=0.158]
Epoch 0:  94%|█████████▍| 126/134 [36:43<02:19, 17.49s/it, v_num=0, train_loss=0.158]
Epoch 0:  94%|█████████▍| 126/134 [36:43<02:19, 17.49s/it, v_num=0, train_loss=0.168]
Epoch 0:  95%|█████████▍| 127/134 [37:00<02:02, 17.49s/it, v_num=0, train_loss=0.168]
Epoch 0:  95%|█████████▍| 127/134 [37:01<02:02, 17.49s/it, v_num=0, train_loss=0.182]
Epoch 0:  96%|█████████▌| 128/134 [37:18<01:44, 17.49s/it, v_num=0, train_loss=0.182]
Epoch 0:  96%|█████████▌| 128/134 [37:18<01:44, 17.49s/it, v_num=0, train_loss=0.204]
Epoch 0:  96%|█████████▋| 129/134 [37:35<01:27, 17.49s/it, v_num=0, train_loss=0.204]
Epoch 0:  96%|█████████▋| 129/134 [37:36<01:27, 17.49s/it, v_num=0, train_loss=0.237]
Epoch 0:  97%|█████████▋| 130/134 [37:53<01:09, 17.49s/it, v_num=0, train_loss=0.237]
Epoch 0:  97%|█████████▋| 130/134 [37:53<01:09, 17.49s/it, v_num=0, train_loss=0.234]
Epoch 0:  98%|█████████▊| 131/134 [38:10<00:52, 17.49s/it, v_num=0, train_loss=0.234]
Epoch 0:  98%|█████████▊| 131/134 [38:11<00:52, 17.49s/it, v_num=0, train_loss=0.204]
Epoch 0:  99%|█████████▊| 132/134 [38:28<00:34, 17.49s/it, v_num=0, train_loss=0.204]
Epoch 0:  99%|█████████▊| 132/134 [38:28<00:34, 17.49s/it, v_num=0, train_loss=0.202]
Epoch 0:  99%|█████████▉| 133/134 [38:46<00:17, 17.49s/it, v_num=0, train_loss=0.202]
Epoch 0:  99%|█████████▉| 133/134 [38:46<00:17, 17.49s/it, v_num=0, train_loss=0.170]
Epoch 0: 100%|██████████| 134/134 [39:03<00:00, 17.49s/it, v_num=0, train_loss=0.170]
Epoch 0: 100%|██████████| 134/134 [39:03<00:00, 17.49s/it, v_num=0, train_loss=0.161]
Epoch 0: : 135it [39:21, 17.49s/it, v_num=0, train_loss=0.161]                       
Epoch 0: : 135it [39:21, 17.49s/it, v_num=0, train_loss=0.167]

Trial Progress

Trial name | _report_on | date | done | epoch | experiment_tag | hostname | iterations_since_restore | node_ip | pid | should_checkpoint | step | time_since_restore | time_this_iter_s | time_total_s | timestamp | train_loss | training_iteration | trial_id
LightningTrainer_e0990_00000 | train_epoch_end | 2023-05-05_01-02-26 | True | 0 | 0 | ip-10-0-102-147 | 1 | 10.0.102.147 | 41219 | True | 135 | 2699.78 | 2699.78 | 2699.78 | 1683273746 | 0.166992 | 1 | e0990_00000
(RayTrainWorker pid=41376) `Trainer.fit` stopped: `max_epochs=1` reached.
(RayTrainWorker pid=41376) RayFSDPStrategy: tearing down strategy...

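The training loop finished in 2361s (39:21 on the progress bar above), and the whole trial took about 2699s (time_total_s in the table above). The price for an on-demand g4dn.4xlarge instance is $1.204/hour, while a g4dn.8xlarge instance costs $2.176/hour. The total cost would be ($1.204 * 15 + $2.176) * 2699 / 3600 = $15.17.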
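The cost arithmetic above can be reproduced with a few lines of Python. This is only a sketch: the function name `cluster_cost` and its parameters are hypothetical, and the instance counts (15 and 1), hourly rates, and 2699s duration are taken directly from the formula in the text.

```python
# Hourly on-demand rates quoted in the text (USD per instance-hour).
RATE_LOW = 1.204
RATE_HIGH = 2.176


def cluster_cost(seconds, n_low=15, n_high=1):
    """Return the total cost in USD of running the cluster for `seconds`.

    Hypothetical helper: the default counts mirror the formula in the text,
    which bills 15 instances at the lower rate and 1 at the higher rate.
    """
    hourly_rate = RATE_LOW * n_low + RATE_HIGH * n_high
    return hourly_rate * seconds / 3600


total = cluster_cost(2699)
print(f"Estimated cost: ${total:.2f}")
```

Changing `seconds` or the instance counts gives a quick estimate for a different run length or cluster size under the same pricing assumptions.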

Text-generation with HuggingFace Pipeline#

We can use the HuggingFace Pipeline to generate predictions from our fine-tuned model. Let’s input some prompts and see if our tuned Dolly can speak like Shakespeare:

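import torch
from transformers import AutoTokenizer, pipeline

# Reload the tokenizer that matches the base model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="right")

# Restore the fine-tuned LightningModule from the checkpoint onto the CPU.
dolly = result.checkpoint.get_model(model_class=DollyV2Model, map_location=torch.device("cpu"))

# Wrap the underlying HuggingFace model in a text-generation pipeline.
nlp_pipeline = pipeline(
    task="text-generation",
    model=dolly.model,
    tokenizer=tokenizer,
    device_map="auto",
)

# Sample up to 20 new tokens for each prompt.
for prompt in ["This is", "I am", "Once more"]:
    print(nlp_pipeline(prompt, max_new_tokens=20, do_sample=True, pad_token_id=tokenizer.eos_token_id))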
[{'generated_text': 'This is the very place, my lord, where I was born.'}]
[{'generated_text': 'I am a man of a thousand lives, and I will live.'}]
[{'generated_text': 'Once more, my lord, I beseech you, hear me speak.'}]
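
As the printed results show, each pipeline call returns a list of dictionaries with a `generated_text` key. A minimal sketch of pulling out just the strings follows. The helper name `extract_texts` is hypothetical, and the sample data is copied from the output above, so the snippet runs without loading the model.

```python
# Sample results copied from the run above: one list of dicts per prompt.
outputs = [
    [{"generated_text": "This is the very place, my lord, where I was born."}],
    [{"generated_text": "I am a man of a thousand lives, and I will live."}],
    [{"generated_text": "Once more, my lord, I beseech you, hear me speak."}],
]


def extract_texts(batches):
    """Flatten pipeline results into a plain list of generated strings."""
    return [item["generated_text"] for batch in batches for item in batch]


texts = extract_texts(outputs)
for text in texts:
    print(text)
```

Because `do_sample=True` is set in the example, a fresh run will generally produce different continuations from the ones listed here.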

References: