mirror of
https://github.com/BillyOutlast/flash-attention.git
synced 2026-07-01 21:04:02 -04:00
Run isort and black on python files
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
__version__ = "2.0.8"
|
||||
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
|
||||
from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
from flash_attn.flash_attn_interface import (
|
||||
flash_attn_func,
|
||||
flash_attn_kvpacked_func,
|
||||
flash_attn_qkvpacked_func,
|
||||
flash_attn_varlen_func,
|
||||
flash_attn_varlen_kvpacked_func,
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
|
||||
+23
-18
@@ -2,12 +2,10 @@
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
class IndexFirstAxis(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, indices):
|
||||
ctx.save_for_backward(indices)
|
||||
@@ -16,20 +14,24 @@ class IndexFirstAxis(torch.autograd.Function):
|
||||
second_dim = other_shape.numel()
|
||||
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
||||
# return input[indices]
|
||||
return torch.gather(rearrange(input, 'b ... -> b (...)'), 0,
|
||||
repeat(indices, 'z -> z d', d=second_dim)).reshape(-1, *other_shape)
|
||||
return torch.gather(
|
||||
rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
|
||||
).reshape(-1, *other_shape)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
indices, = ctx.saved_tensors
|
||||
(indices,) = ctx.saved_tensors
|
||||
assert grad_output.ndim >= 2
|
||||
other_shape = grad_output.shape[1:]
|
||||
grad_output = rearrange(grad_output, 'b ... -> b (...)')
|
||||
grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]],
|
||||
device=grad_output.device, dtype=grad_output.dtype)
|
||||
grad_output = rearrange(grad_output, "b ... -> b (...)")
|
||||
grad_input = torch.zeros(
|
||||
[ctx.first_axis_dim, grad_output.shape[1]],
|
||||
device=grad_output.device,
|
||||
dtype=grad_output.dtype,
|
||||
)
|
||||
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
|
||||
# grad_input[indices] = grad_output
|
||||
grad_input.scatter_(0, repeat(indices, 'z -> z d', d=grad_output.shape[1]), grad_output)
|
||||
grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
|
||||
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
|
||||
|
||||
|
||||
@@ -37,14 +39,14 @@ index_first_axis = IndexFirstAxis.apply
|
||||
|
||||
|
||||
class IndexPutFirstAxis(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, values, indices, first_axis_dim):
|
||||
ctx.save_for_backward(indices)
|
||||
assert indices.ndim == 1
|
||||
assert values.ndim >= 2
|
||||
output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device,
|
||||
dtype=values.dtype)
|
||||
output = torch.zeros(
|
||||
first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
|
||||
)
|
||||
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
|
||||
output[indices] = values
|
||||
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
|
||||
@@ -52,7 +54,7 @@ class IndexPutFirstAxis(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
indices, = ctx.saved_tensors
|
||||
(indices,) = ctx.saved_tensors
|
||||
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
||||
grad_values = grad_output[indices]
|
||||
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
|
||||
@@ -63,7 +65,6 @@ index_put_first_axis = IndexPutFirstAxis.apply
|
||||
|
||||
|
||||
class IndexFirstAxisResidual(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, indices):
|
||||
ctx.save_for_backward(indices)
|
||||
@@ -79,7 +80,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output, grad_residual):
|
||||
indices, = ctx.saved_tensors
|
||||
(indices,) = ctx.saved_tensors
|
||||
assert grad_output.ndim >= 2
|
||||
other_shape = grad_output.shape[1:]
|
||||
assert grad_residual.shape[1:] == other_shape
|
||||
@@ -113,8 +114,12 @@ def unpad_input(hidden_states, attention_mask):
|
||||
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
|
||||
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
|
||||
# so we write custom forward and backward to make it a bit faster.
|
||||
return (index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'), indices), indices,
|
||||
cu_seqlens, max_seqlen_in_batch)
|
||||
return (
|
||||
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
|
||||
indices,
|
||||
cu_seqlens,
|
||||
max_seqlen_in_batch,
|
||||
)
|
||||
|
||||
|
||||
def pad_input(hidden_states, indices, batch, seqlen):
|
||||
@@ -129,4 +134,4 @@ def pad_input(hidden_states, indices, batch, seqlen):
|
||||
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
# output[indices] = hidden_states
|
||||
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
|
||||
return rearrange(output, '(b s) ... -> b s ...', b=batch)
|
||||
return rearrange(output, "(b s) ... -> b s ...", b=batch)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import flash_attn_2_cuda as flash_attn_cuda
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import flash_attn_2_cuda as flash_attn_cuda
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
@@ -45,40 +44,109 @@ def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softma
|
||||
return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
|
||||
|
||||
|
||||
def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal, return_softmax):
|
||||
def _flash_attn_varlen_forward(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
causal,
|
||||
return_softmax,
|
||||
):
|
||||
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
|
||||
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
|
||||
q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
|
||||
softmax_scale, False, causal, return_softmax, None
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
None,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
return_softmax,
|
||||
None,
|
||||
)
|
||||
# if out.isnan().any() or softmax_lse.isnan().any():
|
||||
# breakpoint()
|
||||
return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
|
||||
|
||||
|
||||
def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
|
||||
dropout_p, softmax_scale, causal, rng_state=None):
|
||||
def _flash_attn_backward(
|
||||
dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, rng_state=None
|
||||
):
|
||||
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
|
||||
# dq, dk, dv are allocated by us so they should already be contiguous
|
||||
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
|
||||
dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
|
||||
dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p,
|
||||
softmax_scale, causal, None, rng_state
|
||||
dout,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
softmax_lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
causal,
|
||||
None,
|
||||
rng_state,
|
||||
)
|
||||
return dq, dk, dv, softmax_d
|
||||
|
||||
|
||||
def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
|
||||
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal, rng_state=None):
|
||||
def _flash_attn_varlen_backward(
|
||||
dout,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
softmax_lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
causal,
|
||||
rng_state=None,
|
||||
):
|
||||
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
|
||||
# dq, dk, dv are allocated by us so they should already be contiguous
|
||||
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
|
||||
dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
|
||||
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
|
||||
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None, rng_state
|
||||
dout,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
softmax_lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
None,
|
||||
rng_state,
|
||||
)
|
||||
# if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
|
||||
# breakpoint()
|
||||
@@ -86,14 +154,18 @@ def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
|
||||
|
||||
|
||||
class FlashAttnQKVPackedFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax):
|
||||
if softmax_scale is None:
|
||||
softmax_scale = qkv.shape[-1] ** (-0.5)
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
|
||||
qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], dropout_p, softmax_scale,
|
||||
causal=causal, return_softmax=return_softmax and dropout_p > 0
|
||||
qkv[:, :, 0],
|
||||
qkv[:, :, 1],
|
||||
qkv[:, :, 2],
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
causal=causal,
|
||||
return_softmax=return_softmax and dropout_p > 0,
|
||||
)
|
||||
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
@@ -107,22 +179,41 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
|
||||
qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
|
||||
dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
|
||||
_flash_attn_backward(
|
||||
dout, q, k, v, out, softmax_lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
|
||||
ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state
|
||||
dout,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
softmax_lse,
|
||||
dqkv[:, :, 0],
|
||||
dqkv[:, :, 1],
|
||||
dqkv[:, :, 2],
|
||||
ctx.dropout_p,
|
||||
ctx.softmax_scale,
|
||||
ctx.causal,
|
||||
rng_state=rng_state,
|
||||
)
|
||||
dqkv = dqkv[..., :dout.shape[-1]] # We could have padded the head dimension
|
||||
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
|
||||
return dqkv, None, None, None, None
|
||||
|
||||
|
||||
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
|
||||
if softmax_scale is None:
|
||||
softmax_scale = qkv.shape[-1] ** (-0.5)
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
|
||||
qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
|
||||
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
|
||||
qkv[:, 0],
|
||||
qkv[:, 1],
|
||||
qkv[:, 2],
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
max_seqlen,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
causal=causal,
|
||||
return_softmax=return_softmax and dropout_p > 0,
|
||||
)
|
||||
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
@@ -137,23 +228,41 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
|
||||
qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
|
||||
dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
|
||||
_flash_attn_varlen_backward(
|
||||
dout, q, k, v, out, softmax_lse, dqkv[:, 0], dqkv[:, 1], dqkv[:, 2],
|
||||
cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
|
||||
ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state
|
||||
dout,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
softmax_lse,
|
||||
dqkv[:, 0],
|
||||
dqkv[:, 1],
|
||||
dqkv[:, 2],
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
ctx.max_seqlen,
|
||||
ctx.max_seqlen,
|
||||
ctx.dropout_p,
|
||||
ctx.softmax_scale,
|
||||
ctx.causal,
|
||||
rng_state=rng_state,
|
||||
)
|
||||
dqkv = dqkv[..., :dout.shape[-1]] # We could have padded the head dimension
|
||||
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
|
||||
return dqkv, None, None, None, None, None, None
|
||||
|
||||
|
||||
class FlashAttnKVPackedFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax):
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
|
||||
q, kv[:, :, 0], kv[:, :, 1], dropout_p, softmax_scale, causal=causal,
|
||||
return_softmax=return_softmax and dropout_p > 0
|
||||
q,
|
||||
kv[:, :, 0],
|
||||
kv[:, :, 1],
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
causal=causal,
|
||||
return_softmax=return_softmax and dropout_p > 0,
|
||||
)
|
||||
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
@@ -168,28 +277,58 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
|
||||
kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
|
||||
dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
|
||||
_flash_attn_backward(
|
||||
dout, q, k, v, out, softmax_lse,
|
||||
dq, dkv[:, :, 0], dkv[:, :, 1], ctx.dropout_p, ctx.softmax_scale, ctx.causal,
|
||||
rng_state=rng_state
|
||||
dout,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
softmax_lse,
|
||||
dq,
|
||||
dkv[:, :, 0],
|
||||
dkv[:, :, 1],
|
||||
ctx.dropout_p,
|
||||
ctx.softmax_scale,
|
||||
ctx.causal,
|
||||
rng_state=rng_state,
|
||||
)
|
||||
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
|
||||
dkv = dkv[..., :dout.shape[-1]]
|
||||
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
|
||||
dkv = dkv[..., : dout.shape[-1]]
|
||||
return dq, dkv, None, None, None, None
|
||||
|
||||
|
||||
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
|
||||
softmax_scale, causal, return_softmax):
|
||||
def forward(
|
||||
ctx,
|
||||
q,
|
||||
kv,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
causal,
|
||||
return_softmax,
|
||||
):
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
|
||||
q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
|
||||
q,
|
||||
kv[:, 0],
|
||||
kv[:, 1],
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
causal=causal,
|
||||
return_softmax=return_softmax and dropout_p > 0,
|
||||
)
|
||||
ctx.save_for_backward(
|
||||
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
|
||||
)
|
||||
ctx.save_for_backward(q, k, v, out_padded, softmax_lse,
|
||||
cu_seqlens_q, cu_seqlens_k, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.max_seqlen_q = max_seqlen_q
|
||||
ctx.max_seqlen_k = max_seqlen_k
|
||||
@@ -204,24 +343,42 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
|
||||
kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
|
||||
dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
|
||||
_flash_attn_varlen_backward(
|
||||
dout, q, k, v, out, softmax_lse, dq, dkv[:, 0], dkv[:, 1],
|
||||
cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k,
|
||||
ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state
|
||||
dout,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
softmax_lse,
|
||||
dq,
|
||||
dkv[:, 0],
|
||||
dkv[:, 1],
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
ctx.max_seqlen_q,
|
||||
ctx.max_seqlen_k,
|
||||
ctx.dropout_p,
|
||||
ctx.softmax_scale,
|
||||
ctx.causal,
|
||||
rng_state=rng_state,
|
||||
)
|
||||
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
|
||||
dkv = dkv[..., :dout.shape[-1]]
|
||||
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
|
||||
dkv = dkv[..., : dout.shape[-1]]
|
||||
return dq, dkv, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class FlashAttnFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax):
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
|
||||
q, k, v, dropout_p, softmax_scale, causal=causal,
|
||||
return_softmax=return_softmax and dropout_p > 0
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
causal=causal,
|
||||
return_softmax=return_softmax and dropout_p > 0,
|
||||
)
|
||||
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
@@ -234,29 +391,60 @@ class FlashAttnFunc(torch.autograd.Function):
|
||||
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
|
||||
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
||||
_flash_attn_backward(
|
||||
dout, q, k, v, out, softmax_lse,
|
||||
dq, dk, dv, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
|
||||
rng_state=rng_state
|
||||
dout,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
softmax_lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
ctx.dropout_p,
|
||||
ctx.softmax_scale,
|
||||
ctx.causal,
|
||||
rng_state=rng_state,
|
||||
)
|
||||
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
|
||||
dk = dk[..., :dout.shape[-1]]
|
||||
dv = dv[..., :dout.shape[-1]]
|
||||
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
|
||||
dk = dk[..., : dout.shape[-1]]
|
||||
dv = dv[..., : dout.shape[-1]]
|
||||
return dq, dk, dv, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class FlashAttnVarlenFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
|
||||
softmax_scale, causal, return_softmax):
|
||||
def forward(
|
||||
ctx,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
causal,
|
||||
return_softmax,
|
||||
):
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
|
||||
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
causal=causal,
|
||||
return_softmax=return_softmax and dropout_p > 0,
|
||||
)
|
||||
ctx.save_for_backward(
|
||||
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
|
||||
)
|
||||
ctx.save_for_backward(q, k, v, out_padded, softmax_lse,
|
||||
cu_seqlens_q, cu_seqlens_k, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.max_seqlen_q = max_seqlen_q
|
||||
ctx.max_seqlen_k = max_seqlen_k
|
||||
@@ -269,18 +457,33 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
|
||||
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
|
||||
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
||||
_flash_attn_varlen_backward(
|
||||
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
|
||||
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
|
||||
rng_state=rng_state
|
||||
dout,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
softmax_lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
ctx.max_seqlen_q,
|
||||
ctx.max_seqlen_k,
|
||||
ctx.dropout_p,
|
||||
ctx.softmax_scale,
|
||||
ctx.causal,
|
||||
rng_state=rng_state,
|
||||
)
|
||||
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
|
||||
dk = dk[..., :dout.shape[-1]]
|
||||
dv = dv[..., :dout.shape[-1]]
|
||||
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
|
||||
dk = dk[..., : dout.shape[-1]]
|
||||
dv = dv[..., : dout.shape[-1]]
|
||||
return dq, dk, dv, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
|
||||
return_attn_probs=False):
|
||||
def flash_attn_qkvpacked_func(
|
||||
qkv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
|
||||
):
|
||||
"""dropout_p should be set to 0.0 during evaluation
|
||||
If Q, K, V are already stacked into 1 tensor, this function will be faster than
|
||||
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
|
||||
@@ -309,8 +512,9 @@ def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=Fal
|
||||
return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, return_attn_probs)
|
||||
|
||||
|
||||
def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=False,
|
||||
return_attn_probs=False):
|
||||
def flash_attn_kvpacked_func(
|
||||
q, kv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
|
||||
):
|
||||
"""dropout_p should be set to 0.0 during evaluation
|
||||
If K, V are already stacked into 1 tensor, this function will be faster than
|
||||
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
|
||||
@@ -342,8 +546,9 @@ def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=Fa
|
||||
return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, return_attn_probs)
|
||||
|
||||
|
||||
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
|
||||
return_attn_probs=False):
|
||||
def flash_attn_func(
|
||||
q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
|
||||
):
|
||||
"""dropout_p should be set to 0.0 during evaluation
|
||||
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
|
||||
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
|
||||
@@ -373,8 +578,15 @@ def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
|
||||
return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs)
|
||||
|
||||
|
||||
def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None,
|
||||
causal=False, return_attn_probs=False):
|
||||
def flash_attn_varlen_qkvpacked_func(
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
causal=False,
|
||||
return_attn_probs=False,
|
||||
):
|
||||
"""dropout_p should be set to 0.0 during evaluation
|
||||
If Q, K, V are already stacked into 1 tensor, this function will be faster than
|
||||
calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
|
||||
@@ -408,9 +620,18 @@ def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0,
|
||||
)
|
||||
|
||||
|
||||
def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p=0.0, softmax_scale=None, causal=False,
|
||||
return_attn_probs=False):
|
||||
def flash_attn_varlen_kvpacked_func(
|
||||
q,
|
||||
kv,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
causal=False,
|
||||
return_attn_probs=False,
|
||||
):
|
||||
"""dropout_p should be set to 0.0 during evaluation
|
||||
If K, V are already stacked into 1 tensor, this function will be faster than
|
||||
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
|
||||
@@ -446,14 +667,32 @@ def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqle
|
||||
pattern (negative means that location was dropped, nonnegative means it was kept).
|
||||
"""
|
||||
return FlashAttnVarlenKVPackedFunc.apply(
|
||||
q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal, return_attn_probs
|
||||
q,
|
||||
kv,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
causal,
|
||||
return_attn_probs,
|
||||
)
|
||||
|
||||
|
||||
def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p=0.0, softmax_scale=None, causal=False,
|
||||
return_attn_probs=False):
|
||||
def flash_attn_varlen_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
causal=False,
|
||||
return_attn_probs=False,
|
||||
):
|
||||
"""dropout_p should be set to 0.0 during evaluation
|
||||
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
|
||||
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
|
||||
@@ -487,6 +726,15 @@ def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, ma
|
||||
pattern (negative means that location was dropped, nonnegative means it was kept).
|
||||
"""
|
||||
return FlashAttnVarlenFunc.apply(
|
||||
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal, return_attn_probs
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
causal,
|
||||
return_attn_probs,
|
||||
)
|
||||
|
||||
+533
-205
File diff suppressed because it is too large
Load Diff
@@ -11,22 +11,41 @@ This is a Triton implementation of the Flash Attention algorithm
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, sm_scale,
|
||||
TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
sm_scale,
|
||||
TMP,
|
||||
L,
|
||||
M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
Z, H, N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
stride_qz,
|
||||
stride_qh,
|
||||
stride_qm,
|
||||
stride_qk,
|
||||
stride_kz,
|
||||
stride_kh,
|
||||
stride_kn,
|
||||
stride_kk,
|
||||
stride_vz,
|
||||
stride_vh,
|
||||
stride_vk,
|
||||
stride_vn,
|
||||
stride_oz,
|
||||
stride_oh,
|
||||
stride_om,
|
||||
stride_on,
|
||||
Z,
|
||||
H,
|
||||
N_CTX,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
@@ -100,9 +119,13 @@ def _fwd_kernel(
|
||||
|
||||
@triton.jit
|
||||
def _bwd_preprocess(
|
||||
Out, DO, L,
|
||||
NewDO, Delta,
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
||||
Out,
|
||||
DO,
|
||||
L,
|
||||
NewDO,
|
||||
Delta,
|
||||
BLOCK_M: tl.constexpr,
|
||||
D_HEAD: tl.constexpr,
|
||||
):
|
||||
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, D_HEAD)
|
||||
@@ -120,16 +143,36 @@ def _bwd_preprocess(
|
||||
|
||||
@triton.jit
|
||||
def _bwd_kernel(
|
||||
Q, K, V, sm_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L, M,
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
sm_scale,
|
||||
Out,
|
||||
DO,
|
||||
DQ,
|
||||
DK,
|
||||
DV,
|
||||
L,
|
||||
M,
|
||||
D,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
Z, H, N_CTX,
|
||||
stride_qz,
|
||||
stride_qh,
|
||||
stride_qm,
|
||||
stride_qk,
|
||||
stride_kz,
|
||||
stride_kh,
|
||||
stride_kn,
|
||||
stride_kk,
|
||||
stride_vz,
|
||||
stride_vh,
|
||||
stride_vk,
|
||||
stride_vn,
|
||||
Z,
|
||||
H,
|
||||
N_CTX,
|
||||
num_block,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
off_hz = tl.program_id(0)
|
||||
@@ -203,7 +246,6 @@ def _bwd_kernel(
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, sm_scale):
|
||||
BLOCK = 128
|
||||
@@ -213,22 +255,45 @@ class _attention(torch.autograd.Function):
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
o = torch.empty_like(q)
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
|
||||
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
tmp = torch.empty(
|
||||
(q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
|
||||
)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
|
||||
_fwd_kernel[grid](
|
||||
q, k, v, sm_scale,
|
||||
tmp, L, m,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
sm_scale,
|
||||
tmp,
|
||||
L,
|
||||
m,
|
||||
o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=Lk, num_warps=num_warps,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
q.stride(3),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
k.stride(3),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
v.stride(3),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
o.stride(2),
|
||||
o.stride(3),
|
||||
q.shape[0],
|
||||
q.shape[1],
|
||||
q.shape[2],
|
||||
BLOCK_M=BLOCK,
|
||||
BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=Lk,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
ctx.save_for_backward(q, k, v, o, L, m)
|
||||
@@ -247,27 +312,51 @@ class _attention(torch.autograd.Function):
|
||||
dv = torch.empty_like(v)
|
||||
do_scaled = torch.empty_like(do)
|
||||
delta = torch.empty_like(l)
|
||||
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
|
||||
o, do, l,
|
||||
do_scaled, delta,
|
||||
BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](
|
||||
o,
|
||||
do,
|
||||
l,
|
||||
do_scaled,
|
||||
delta,
|
||||
BLOCK_M=ctx.BLOCK,
|
||||
D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
|
||||
# NOTE: kernel currently buggy for other values of `num_warps`
|
||||
num_warps = 8
|
||||
_bwd_kernel[(ctx.grid[1],)](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
dq, dk, dv,
|
||||
l, m,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
ctx.sm_scale,
|
||||
o,
|
||||
do_scaled,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
l,
|
||||
m,
|
||||
delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
q.stride(3),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(2),
|
||||
k.stride(3),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(2),
|
||||
v.stride(3),
|
||||
q.shape[0],
|
||||
q.shape[1],
|
||||
q.shape[2],
|
||||
ctx.grid[0],
|
||||
BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
|
||||
BLOCK_M=ctx.BLOCK,
|
||||
BLOCK_N=ctx.BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
return dq.to(q.dtype), dk, dv, None
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
|
||||
from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
|
||||
from flash_attn.flash_blocksparse_attn_interface import convert_blockmask
|
||||
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
|
||||
from flash_attn.flash_blocksparse_attn_interface import (
|
||||
convert_blockmask,
|
||||
flash_blocksparse_attn_func,
|
||||
)
|
||||
|
||||
|
||||
class FlashBlocksparseAttention(nn.Module):
|
||||
@@ -21,8 +22,16 @@ class FlashBlocksparseAttention(nn.Module):
|
||||
attention_dropout: The dropout rate to apply to the attention
|
||||
(default: 0.1)
|
||||
"""
|
||||
def __init__(self, sparsity_config, softmax_temp=None, attention_dropout=0.0,
|
||||
max_seq_length=2048, device=None, dtype=None):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sparsity_config,
|
||||
softmax_temp=None,
|
||||
attention_dropout=0.0,
|
||||
max_seq_length=2048,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.sparsity_config = hydra.utils.instantiate(sparsity_config)
|
||||
self.softmax_temp = softmax_temp
|
||||
@@ -36,8 +45,17 @@ class FlashBlocksparseAttention(nn.Module):
|
||||
self.register_buffer("blockmask_converted", blockmask_converted)
|
||||
# logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}')
|
||||
|
||||
def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
|
||||
max_s=None, need_weights=False, convert_mask=True):
|
||||
def forward(
|
||||
self,
|
||||
qkv,
|
||||
attn_mask=None,
|
||||
key_padding_mask=None,
|
||||
causal=False,
|
||||
cu_seqlens=None,
|
||||
max_s=None,
|
||||
need_weights=False,
|
||||
convert_mask=True,
|
||||
):
|
||||
"""Implements the multihead softmax attention.
|
||||
Arguments
|
||||
---------
|
||||
@@ -57,47 +75,76 @@ class FlashBlocksparseAttention(nn.Module):
|
||||
seqlen = qkv.shape[1]
|
||||
# Convert mask to take a subset
|
||||
seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
|
||||
assert seqlen_rounded // 16 <= self.layout.shape[0], seqlen_rounded // 256 <= self.layout.shape[1]
|
||||
blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
|
||||
assert seqlen_rounded // 16 <= self.layout.shape[0], (
|
||||
seqlen_rounded // 256 <= self.layout.shape[1]
|
||||
)
|
||||
blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
|
||||
if key_padding_mask is None:
|
||||
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
max_s = seqlen
|
||||
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
|
||||
device=qkv.device)
|
||||
output = flash_blocksparse_attn_func(
|
||||
qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
|
||||
max_s, softmax_scale=self.softmax_temp, causal=causal
|
||||
cu_seqlens = torch.arange(
|
||||
0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device
|
||||
)
|
||||
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
|
||||
output = flash_blocksparse_attn_func(
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
blockmask,
|
||||
self.dropout_p if self.training else 0.0,
|
||||
max_s,
|
||||
softmax_scale=self.softmax_temp,
|
||||
causal=causal,
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
|
||||
else:
|
||||
key_padding_mask_bool = key_padding_mask.bool_matrix
|
||||
nheads = qkv.shape[-2]
|
||||
x = rearrange(qkv, 'b s three h d -> b s (three h d)')
|
||||
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
||||
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
|
||||
x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
|
||||
x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
|
||||
output_unpad = flash_blocksparse_attn_func(
|
||||
x_unpad, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
|
||||
max_s, softmax_scale=self.softmax_temp, causal=causal
|
||||
x_unpad,
|
||||
cu_seqlens,
|
||||
blockmask,
|
||||
self.dropout_p if self.training else 0.0,
|
||||
max_s,
|
||||
softmax_scale=self.softmax_temp,
|
||||
causal=causal,
|
||||
)
|
||||
output = rearrange(
|
||||
pad_input(
|
||||
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
|
||||
),
|
||||
"b s (h d) -> b s h d",
|
||||
h=nheads,
|
||||
)
|
||||
output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
|
||||
indices, batch_size, seqlen),
|
||||
'b s (h d) -> b s h d', h=nheads)
|
||||
else:
|
||||
assert max_s is not None
|
||||
seqlen = max_s
|
||||
# Convert mask to take a subset
|
||||
seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
|
||||
assert seqlen_rounded // 16 <= self.layout.shape[0], seqlen_rounded // 256 <= self.layout.shape[1]
|
||||
blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
|
||||
assert seqlen_rounded // 16 <= self.layout.shape[0], (
|
||||
seqlen_rounded // 256 <= self.layout.shape[1]
|
||||
)
|
||||
blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
|
||||
if convert_mask:
|
||||
output = flash_blocksparse_attn_func(
|
||||
qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
|
||||
max_s, softmax_scale=self.softmax_temp, causal=causal
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
blockmask,
|
||||
self.dropout_p if self.training else 0.0,
|
||||
max_s,
|
||||
softmax_scale=self.softmax_temp,
|
||||
causal=causal,
|
||||
)
|
||||
else:
|
||||
output = flash_blocksparse_attn_func(
|
||||
qkv, cu_seqlens, self.blockmask_converted, self.dropout_p if self.training else 0.0,
|
||||
max_s, softmax_scale=self.softmax_temp, causal=causal,
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
self.blockmask_converted,
|
||||
self.dropout_p if self.training else 0.0,
|
||||
max_s,
|
||||
softmax_scale=self.softmax_temp,
|
||||
causal=causal,
|
||||
convert_mask=False,
|
||||
)
|
||||
|
||||
@@ -105,12 +152,22 @@ class FlashBlocksparseAttention(nn.Module):
|
||||
|
||||
|
||||
class FlashBlocksparseMHA(nn.Module):
|
||||
|
||||
def __init__(self, embed_dim, num_heads, sparsity_config, bias=True, batch_first=True,
|
||||
attention_dropout=0.0, causal=False, max_seq_length=2048,
|
||||
device=None, dtype=None, **kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
sparsity_config,
|
||||
bias=True,
|
||||
batch_first=True,
|
||||
attention_dropout=0.0,
|
||||
causal=False,
|
||||
max_seq_length=2048,
|
||||
device=None,
|
||||
dtype=None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
assert batch_first
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.causal = causal
|
||||
@@ -122,15 +179,19 @@ class FlashBlocksparseMHA(nn.Module):
|
||||
|
||||
self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
|
||||
self.inner_attn = FlashBlocksparseAttention(
|
||||
sparsity_config, attention_dropout=attention_dropout,
|
||||
max_seq_length=max_seq_length, **factory_kwargs
|
||||
sparsity_config,
|
||||
attention_dropout=attention_dropout,
|
||||
max_seq_length=max_seq_length,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
|
||||
|
||||
def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
|
||||
need_weights=False):
|
||||
def forward(
|
||||
self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None, need_weights=False
|
||||
):
|
||||
qkv = self.Wqkv(x)
|
||||
qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
|
||||
context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights, causal=self.causal)
|
||||
return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights
|
||||
qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
|
||||
context, attn_weights = self.inner_attn(
|
||||
qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
|
||||
)
|
||||
return self.out_proj(rearrange(context, "b s h d -> b s (h d)")), attn_weights
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
|
||||
import flash_attn_cuda
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import flash_attn_cuda
|
||||
|
||||
|
||||
def convert_blockmask(blockmask, causal):
|
||||
"""Convert from the 0-1 format to the format used by the CUDA code.
|
||||
@@ -40,29 +39,51 @@ def convert_blockmask(blockmask, causal):
|
||||
return nonzero_idx.T.contiguous().to(dtype=torch.int32)
|
||||
|
||||
|
||||
def _flash_blocksparse_attn_forward(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale,
|
||||
causal, return_softmax):
|
||||
context, softmax_lse, *rest = flash_attn_cuda.fwd_block(qkv, cu_seqlens, blockmask, dropout_p,
|
||||
max_s, softmax_scale, causal,
|
||||
return_softmax, None)
|
||||
def _flash_blocksparse_attn_forward(
|
||||
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax
|
||||
):
|
||||
context, softmax_lse, *rest = flash_attn_cuda.fwd_block(
|
||||
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax, None
|
||||
)
|
||||
# if context.isnan().any() or softmax_lse.isnan().any():
|
||||
# breakpoint()
|
||||
S_dmask = rest[0] if return_softmax else None
|
||||
return context, softmax_lse, S_dmask
|
||||
|
||||
|
||||
def _flash_blocksparse_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, blockmask,
|
||||
dropout_p, max_s, softmax_scale, causal):
|
||||
dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens,
|
||||
blockmask, dropout_p, softmax_scale, max_s,
|
||||
causal, None)
|
||||
def _flash_blocksparse_attn_backward(
|
||||
dout,
|
||||
qkv,
|
||||
out,
|
||||
S_dmask,
|
||||
softmax_lse,
|
||||
cu_seqlens,
|
||||
blockmask,
|
||||
dropout_p,
|
||||
max_s,
|
||||
softmax_scale,
|
||||
causal,
|
||||
):
|
||||
dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(
|
||||
dout,
|
||||
qkv,
|
||||
out,
|
||||
S_dmask,
|
||||
softmax_lse,
|
||||
cu_seqlens,
|
||||
blockmask,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
max_s,
|
||||
causal,
|
||||
None,
|
||||
)
|
||||
# if dqkv.isnan().any() or softmax_d.isnan().any():
|
||||
# breakpoint()
|
||||
return dqkv
|
||||
|
||||
|
||||
class FlashBlocksparseAttnFun(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
@@ -70,8 +91,14 @@ class FlashBlocksparseAttnFun(torch.autograd.Function):
|
||||
if softmax_scale is None:
|
||||
softmax_scale = qkv.shape[-1] ** (-0.5)
|
||||
context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
|
||||
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
|
||||
return_softmax=False
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
blockmask,
|
||||
dropout_p,
|
||||
max_s,
|
||||
softmax_scale,
|
||||
causal=causal,
|
||||
return_softmax=False,
|
||||
)
|
||||
ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
@@ -88,8 +115,17 @@ class FlashBlocksparseAttnFun(torch.autograd.Function):
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
# S_dmask is None, temporarily use another tensor just to get it running
|
||||
dqkv = _flash_blocksparse_attn_backward(
|
||||
dout, qkv, context, context, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
|
||||
ctx.max_s, ctx.softmax_scale, ctx.causal
|
||||
dout,
|
||||
qkv,
|
||||
context,
|
||||
context,
|
||||
softmax_lse,
|
||||
cu_seqlens,
|
||||
blockmask,
|
||||
ctx.dropout_p,
|
||||
ctx.max_s,
|
||||
ctx.softmax_scale,
|
||||
ctx.causal,
|
||||
)
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
@@ -99,7 +135,6 @@ class FlashBlocksparseAttnFun(torch.autograd.Function):
|
||||
# We duplicate code to return both the output and the softmax for testing
|
||||
# Returning both makes backward a bit slower, so we want to keep using the other version for speed.
|
||||
class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
|
||||
# Save rng_state because the backward pass is gonna regenerate the dropout mask
|
||||
@@ -107,8 +142,14 @@ class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
|
||||
if softmax_scale is None:
|
||||
softmax_scale = qkv.shape[-1] ** (-0.5)
|
||||
context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
|
||||
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
|
||||
return_softmax=True
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
blockmask,
|
||||
dropout_p,
|
||||
max_s,
|
||||
softmax_scale,
|
||||
causal=causal,
|
||||
return_softmax=True,
|
||||
)
|
||||
ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
@@ -124,18 +165,35 @@ class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
|
||||
cur_rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
dqkv = _flash_blocksparse_attn_backward(
|
||||
dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
|
||||
ctx.max_s, ctx.softmax_scale, ctx.causal
|
||||
dout,
|
||||
qkv,
|
||||
context,
|
||||
S_dmask,
|
||||
softmax_lse,
|
||||
cu_seqlens,
|
||||
blockmask,
|
||||
ctx.dropout_p,
|
||||
ctx.max_s,
|
||||
ctx.softmax_scale,
|
||||
ctx.causal,
|
||||
)
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dqkv, None, None, None, None, None, None
|
||||
|
||||
|
||||
def flash_blocksparse_attn_func(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale=None,
|
||||
causal=False, return_attn_probs=False, convert_mask=True):
|
||||
"""dropout_p should be set to 0.0 during evaluation
|
||||
"""
|
||||
def flash_blocksparse_attn_func(
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
blockmask,
|
||||
dropout_p,
|
||||
max_s,
|
||||
softmax_scale=None,
|
||||
causal=False,
|
||||
return_attn_probs=False,
|
||||
convert_mask=True,
|
||||
):
|
||||
"""dropout_p should be set to 0.0 during evaluation"""
|
||||
func = FlashBlocksparseAttnFun if not return_attn_probs else FlashBlocksparseAttnFunWithS
|
||||
if convert_mask:
|
||||
blockmask = convert_blockmask(blockmask, causal=causal)
|
||||
|
||||
+10
-14
@@ -17,13 +17,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from apex._autocast_utils import _cast_if_autocast_enabled
|
||||
from apex.transformer.enums import AttnMaskType
|
||||
|
||||
from fused_softmax_lib import scaled_masked_softmax_forward, scaled_masked_softmax_backward
|
||||
from fused_softmax_lib import scaled_masked_softmax_get_batch_per_block
|
||||
from fused_softmax_lib import scaled_upper_triang_masked_softmax_forward, scaled_upper_triang_masked_softmax_backward
|
||||
from fused_softmax_lib import (
|
||||
scaled_masked_softmax_backward,
|
||||
scaled_masked_softmax_forward,
|
||||
scaled_masked_softmax_get_batch_per_block,
|
||||
scaled_upper_triang_masked_softmax_backward,
|
||||
scaled_upper_triang_masked_softmax_forward,
|
||||
)
|
||||
|
||||
|
||||
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
|
||||
@@ -37,9 +39,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, scale):
|
||||
scale_t = torch.tensor([scale])
|
||||
softmax_results = scaled_upper_triang_masked_softmax_forward(
|
||||
inputs, scale_t[0]
|
||||
)
|
||||
softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0])
|
||||
ctx.save_for_backward(softmax_results, scale_t)
|
||||
return softmax_results
|
||||
|
||||
@@ -81,9 +81,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def backward(ctx, output_grads):
|
||||
softmax_results, scale_t = ctx.saved_tensors
|
||||
input_grads = scaled_masked_softmax_backward(
|
||||
output_grads, softmax_results, scale_t[0]
|
||||
)
|
||||
input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0])
|
||||
return input_grads, None, None
|
||||
|
||||
|
||||
@@ -122,9 +120,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
|
||||
self.input_in_fp16 = input_in_fp16
|
||||
self.input_in_bf16 = input_in_bf16
|
||||
if self.input_in_fp16 and self.input_in_bf16:
|
||||
raise RuntimeError(
|
||||
"both fp16 and bf16 flags cannot be active at the same time."
|
||||
)
|
||||
raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.")
|
||||
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
|
||||
self.attn_mask_type = attn_mask_type
|
||||
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
|
||||
|
||||
@@ -4,11 +4,10 @@
|
||||
from functools import partial
|
||||
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from torch import _assert
|
||||
from torch.nn.modules.utils import _pair
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
try:
|
||||
from flash_attn.ops.fused_dense import FusedDense
|
||||
except ImportError:
|
||||
@@ -16,18 +15,18 @@ except ImportError:
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
""" 2D Image to Patch Embedding
|
||||
"""
|
||||
"""2D Image to Patch Embedding"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
embed_dim=768,
|
||||
norm_layer=None,
|
||||
flatten=True,
|
||||
bias=True,
|
||||
fused_bias_fc=False,
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
embed_dim=768,
|
||||
norm_layer=None,
|
||||
flatten=True,
|
||||
bias=True,
|
||||
fused_bias_fc=False,
|
||||
):
|
||||
super().__init__()
|
||||
img_size = _pair(img_size)
|
||||
@@ -38,7 +37,7 @@ class PatchEmbed(nn.Module):
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
if fused_bias_fc and FusedDense is None:
|
||||
raise ImportError('fused_dense is not installed')
|
||||
raise ImportError("fused_dense is not installed")
|
||||
|
||||
linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense
|
||||
self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)
|
||||
@@ -46,11 +45,23 @@ class PatchEmbed(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
_, _, H, W = x.shape
|
||||
_assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
|
||||
_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
|
||||
x = self.proj(rearrange(x, 'b c (h p1) (w p2) -> b h w (c p1 p2)',
|
||||
p1=self.patch_size[0], p2=self.patch_size[1]))
|
||||
_assert(
|
||||
H == self.img_size[0],
|
||||
f"Input image height ({H}) doesn't match model ({self.img_size[0]}).",
|
||||
)
|
||||
_assert(
|
||||
W == self.img_size[1],
|
||||
f"Input image width ({W}) doesn't match model ({self.img_size[1]}).",
|
||||
)
|
||||
x = self.proj(
|
||||
rearrange(
|
||||
x,
|
||||
"b c (h p1) (w p2) -> b h w (c p1 p2)",
|
||||
p1=self.patch_size[0],
|
||||
p2=self.patch_size[1],
|
||||
)
|
||||
)
|
||||
if self.flatten:
|
||||
x = rearrange(x, 'b h w c -> b (h w) c')
|
||||
x = rearrange(x, "b h w c -> b (h w) c")
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
+172
-80
@@ -1,13 +1,11 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
from typing import Tuple, Optional
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import rotary_emb
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
def rotate_half(x, interleaved=False):
|
||||
@@ -16,7 +14,7 @@ def rotate_half(x, interleaved=False):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
else:
|
||||
x1, x2 = x[..., ::2], x[..., 1::2]
|
||||
return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2)
|
||||
return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
|
||||
|
||||
|
||||
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
|
||||
@@ -26,14 +24,15 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
|
||||
"""
|
||||
ro_dim = cos.shape[-1] * 2
|
||||
assert ro_dim <= x.shape[-1]
|
||||
cos = repeat(cos, 's d -> s 1 (2 d)')
|
||||
sin = repeat(sin, 's d -> s 1 (2 d)')
|
||||
return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
|
||||
x[..., ro_dim:]], dim=-1)
|
||||
cos = repeat(cos, "s d -> s 1 (2 d)")
|
||||
sin = repeat(sin, "s d -> s 1 (2 d)")
|
||||
return torch.cat(
|
||||
[x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
|
||||
class ApplyRotaryEmb(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
|
||||
"""
|
||||
@@ -57,10 +56,20 @@ class ApplyRotaryEmb(torch.autograd.Function):
|
||||
if inplace:
|
||||
o1, o2 = x1, x2
|
||||
else:
|
||||
o1, o2 = (out_ro.chunk(2, dim=-1) if not interleaved
|
||||
else (out_ro[..., ::2], out_ro[..., 1::2]))
|
||||
rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
||||
rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False)
|
||||
o1, o2 = (
|
||||
out_ro.chunk(2, dim=-1)
|
||||
if not interleaved
|
||||
else (out_ro[..., ::2], out_ro[..., 1::2])
|
||||
)
|
||||
rotary_emb.apply_rotary(
|
||||
x1,
|
||||
x2,
|
||||
rearrange(cos[:seqlen], "s d -> s 1 d"),
|
||||
rearrange(sin[:seqlen], "s d -> s 1 d"),
|
||||
o1,
|
||||
o2,
|
||||
False,
|
||||
)
|
||||
if not inplace and rotary_dim < headdim:
|
||||
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
|
||||
ctx.save_for_backward(cos, sin)
|
||||
@@ -76,17 +85,28 @@ class ApplyRotaryEmb(torch.autograd.Function):
|
||||
rotary_dim *= 2
|
||||
inplace = ctx.inplace
|
||||
do_ro = do[..., :rotary_dim]
|
||||
do1, do2 = (do_ro.chunk(2, dim=-1) if not ctx.interleaved
|
||||
else (do_ro[..., ::2], do_ro[..., 1::2]))
|
||||
do1, do2 = (
|
||||
do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2])
|
||||
)
|
||||
dx = torch.empty_like(do) if not inplace else do
|
||||
if inplace:
|
||||
dx1, dx2 = do1, do2
|
||||
else:
|
||||
dx_ro = dx[..., :rotary_dim]
|
||||
dx1, dx2 = (dx_ro.chunk(2, dim=-1) if not ctx.interleaved
|
||||
else (dx_ro[..., ::2], dx_ro[..., 1::2]))
|
||||
rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
||||
rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True)
|
||||
dx1, dx2 = (
|
||||
dx_ro.chunk(2, dim=-1)
|
||||
if not ctx.interleaved
|
||||
else (dx_ro[..., ::2], dx_ro[..., 1::2])
|
||||
)
|
||||
rotary_emb.apply_rotary(
|
||||
do1,
|
||||
do2,
|
||||
rearrange(cos[:seqlen], "s d -> s 1 d"),
|
||||
rearrange(sin[:seqlen], "s d -> s 1 d"),
|
||||
dx1,
|
||||
dx2,
|
||||
True,
|
||||
)
|
||||
if not inplace and rotary_dim < headdim:
|
||||
dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
|
||||
return dx, None, None, None, None
|
||||
@@ -96,7 +116,6 @@ apply_rotary_emb_func = ApplyRotaryEmb.apply
|
||||
|
||||
|
||||
class ApplyRotaryEmbQKV_(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
|
||||
"""
|
||||
@@ -119,12 +138,26 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
|
||||
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
|
||||
q_ro = qkv[:, :, 0, :, :rotary_dim]
|
||||
q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
|
||||
rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
||||
rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
|
||||
rotary_emb.apply_rotary(
|
||||
q1,
|
||||
q2,
|
||||
rearrange(cos[:seqlen], "s d -> s 1 d"),
|
||||
rearrange(sin[:seqlen], "s d -> s 1 d"),
|
||||
q1,
|
||||
q2,
|
||||
False,
|
||||
)
|
||||
k_ro = qkv[:, :, 1, :, :rotary_dim]
|
||||
k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
|
||||
rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
|
||||
rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
|
||||
rotary_emb.apply_rotary(
|
||||
k1,
|
||||
k2,
|
||||
rearrange(cos_k[:seqlen], "s d -> s 1 d"),
|
||||
rearrange(sin_k[:seqlen], "s d -> s 1 d"),
|
||||
k1,
|
||||
k2,
|
||||
False,
|
||||
)
|
||||
ctx.save_for_backward(cos, sin, cos_k, sin_k)
|
||||
ctx.interleaved = interleaved
|
||||
return qkv
|
||||
@@ -136,15 +169,31 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
|
||||
rotary_dim = cos.shape[-1]
|
||||
rotary_dim *= 2
|
||||
dq_ro = dqkv[:, :, 0, :, :rotary_dim]
|
||||
dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved
|
||||
else (dq_ro[..., ::2], dq_ro[..., 1::2]))
|
||||
rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
||||
rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
|
||||
dq1, dq2 = (
|
||||
dq_ro.chunk(2, dim=-1) if not ctx.interleaved else (dq_ro[..., ::2], dq_ro[..., 1::2])
|
||||
)
|
||||
rotary_emb.apply_rotary(
|
||||
dq1,
|
||||
dq2,
|
||||
rearrange(cos[:seqlen], "s d -> s 1 d"),
|
||||
rearrange(sin[:seqlen], "s d -> s 1 d"),
|
||||
dq1,
|
||||
dq2,
|
||||
True,
|
||||
)
|
||||
dk_ro = dqkv[:, :, 1, :, :rotary_dim]
|
||||
dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
|
||||
else (dk_ro[..., ::2], dk_ro[..., 1::2]))
|
||||
rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
|
||||
rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
|
||||
dk1, dk2 = (
|
||||
dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2])
|
||||
)
|
||||
rotary_emb.apply_rotary(
|
||||
dk1,
|
||||
dk2,
|
||||
rearrange(cos_k[:seqlen], "s d -> s 1 d"),
|
||||
rearrange(sin_k[:seqlen], "s d -> s 1 d"),
|
||||
dk1,
|
||||
dk2,
|
||||
True,
|
||||
)
|
||||
return dqkv, None, None, None, None, None
|
||||
|
||||
|
||||
@@ -152,7 +201,6 @@ apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
|
||||
|
||||
|
||||
class ApplyRotaryEmbKV_(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, kv, cos, sin, interleaved=False):
|
||||
"""
|
||||
@@ -171,9 +219,15 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
|
||||
assert seqlen <= rotary_seqlen
|
||||
k_ro = kv[:, :, 0, :, :rotary_dim]
|
||||
k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
|
||||
rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
||||
rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2,
|
||||
False) # conj=False since this is the forward pass
|
||||
rotary_emb.apply_rotary(
|
||||
k1,
|
||||
k2,
|
||||
rearrange(cos[:seqlen], "s d -> s 1 d"),
|
||||
rearrange(sin[:seqlen], "s d -> s 1 d"),
|
||||
k1,
|
||||
k2,
|
||||
False,
|
||||
) # conj=False since this is the forward pass
|
||||
ctx.save_for_backward(cos, sin)
|
||||
ctx.interleaved = interleaved
|
||||
return kv
|
||||
@@ -185,11 +239,18 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
|
||||
rotary_dim = cos.shape[-1]
|
||||
rotary_dim *= 2
|
||||
dk_ro = dkv[:, :, 0, :, :rotary_dim]
|
||||
dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
|
||||
else (dk_ro[..., ::2], dk_ro[..., 1::2]))
|
||||
rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'),
|
||||
rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2,
|
||||
True) # conj=True since this is the backward pass
|
||||
dk1, dk2 = (
|
||||
dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2])
|
||||
)
|
||||
rotary_emb.apply_rotary(
|
||||
dk1,
|
||||
dk2,
|
||||
rearrange(cos[:seqlen], "s d -> s 1 d"),
|
||||
rearrange(sin[:seqlen], "s d -> s 1 d"),
|
||||
dk1,
|
||||
dk2,
|
||||
True,
|
||||
) # conj=True since this is the backward pass
|
||||
return dkv, None, None, None
|
||||
|
||||
|
||||
@@ -214,21 +275,28 @@ class RotaryEmbedding(torch.nn.Module):
|
||||
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None,
|
||||
pos_idx_in_fp32=True, device=None):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
base=10000.0,
|
||||
interleaved=False,
|
||||
scale_base=None,
|
||||
pos_idx_in_fp32=True,
|
||||
device=None,
|
||||
):
|
||||
"""
|
||||
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
|
||||
of 1st half and 2nd half (GPT-NeoX style).
|
||||
pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
|
||||
otherwise they might be in lower precision.
|
||||
This option was added because previously (before 2023-07-02), when we construct
|
||||
the position indices, we use the dtype of self.inv_freq. In most cases this would
|
||||
be fp32, but if the model is trained in pure bf16 (not mixed precision), then
|
||||
self.inv_freq would be bf16, and the position indices are also in bf16.
|
||||
Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
|
||||
embeddings for some positions will coincide.
|
||||
To maintain compatibility with models previously trained in pure bf16,
|
||||
we add this option.
|
||||
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
|
||||
of 1st half and 2nd half (GPT-NeoX style).
|
||||
pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
|
||||
otherwise they might be in lower precision.
|
||||
This option was added because previously (before 2023-07-02), when we construct
|
||||
the position indices, we use the dtype of self.inv_freq. In most cases this would
|
||||
be fp32, but if the model is trained in pure bf16 (not mixed precision), then
|
||||
self.inv_freq would be bf16, and the position indices are also in bf16.
|
||||
Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
|
||||
embeddings for some positions will coincide.
|
||||
To maintain compatibility with models previously trained in pure bf16,
|
||||
we add this option.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -239,8 +307,11 @@ class RotaryEmbedding(torch.nn.Module):
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.interleaved = interleaved
|
||||
self.scale_base = scale_base
|
||||
scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
|
||||
/ (1.4 * dim) if scale_base is not None else None)
|
||||
scale = (
|
||||
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
||||
if scale_base is not None
|
||||
else None
|
||||
)
|
||||
self.register_buffer("scale", scale, persistent=False)
|
||||
|
||||
self._seq_len_cached = 0
|
||||
@@ -250,17 +321,21 @@ class RotaryEmbedding(torch.nn.Module):
|
||||
self._sin_k_cached = None
|
||||
|
||||
def _compute_inv_freq(self, device=None):
|
||||
return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device,
|
||||
dtype=torch.float32) / self.dim))
|
||||
|
||||
return 1.0 / (
|
||||
self.base
|
||||
** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
|
||||
)
|
||||
|
||||
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
|
||||
# Reset the tables if the sequence length has changed,
|
||||
# if we're on a new device (possibly due to tracing for instance),
|
||||
# or if we're switching from inference mode to training
|
||||
if (seqlen > self._seq_len_cached or self._cos_cached.device != device
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype
|
||||
or (self.training and self._cos_cached.is_inference())):
|
||||
or (self.training and self._cos_cached.is_inference())
|
||||
):
|
||||
self._seq_len_cached = seqlen
|
||||
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
|
||||
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
|
||||
@@ -285,17 +360,20 @@ class RotaryEmbedding(torch.nn.Module):
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
else:
|
||||
power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
|
||||
- seqlen // 2) / self.scale_base)
|
||||
scale = self.scale.to(device=power.device) ** rearrange(power, 's -> s 1')
|
||||
power = (
|
||||
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
|
||||
- seqlen // 2
|
||||
) / self.scale_base
|
||||
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
||||
# We want the multiplication by scale to happen in fp32
|
||||
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
||||
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
||||
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
||||
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
||||
|
||||
def forward(self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None,
|
||||
seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def forward(
|
||||
self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None, seqlen_offset: int = 0
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
|
||||
else it's just q of shape (batch, seqlen, nheads, headdim)
|
||||
@@ -308,29 +386,43 @@ class RotaryEmbedding(torch.nn.Module):
|
||||
if kv is None:
|
||||
if self.scale is None:
|
||||
return apply_rotary_emb_qkv_(
|
||||
qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
|
||||
None, None, self.interleaved
|
||||
qkv,
|
||||
self._cos_cached[seqlen_offset:],
|
||||
self._sin_cached[seqlen_offset:],
|
||||
None,
|
||||
None,
|
||||
self.interleaved,
|
||||
)
|
||||
else:
|
||||
return apply_rotary_emb_qkv_(
|
||||
qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
|
||||
self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
|
||||
self.interleaved
|
||||
qkv,
|
||||
self._cos_cached[seqlen_offset:],
|
||||
self._sin_cached[seqlen_offset:],
|
||||
self._cos_k_cached[seqlen_offset:],
|
||||
self._sin_k_cached[seqlen_offset:],
|
||||
self.interleaved,
|
||||
)
|
||||
else:
|
||||
q = qkv
|
||||
q = apply_rotary_emb_func(
|
||||
q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
|
||||
self.interleaved, True
|
||||
q,
|
||||
self._cos_cached[seqlen_offset:],
|
||||
self._sin_cached[seqlen_offset:],
|
||||
self.interleaved,
|
||||
True,
|
||||
)
|
||||
if self.scale is None:
|
||||
kv = apply_rotary_emb_kv_(
|
||||
kv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
|
||||
self.interleaved
|
||||
kv,
|
||||
self._cos_cached[seqlen_offset:],
|
||||
self._sin_cached[seqlen_offset:],
|
||||
self.interleaved,
|
||||
)
|
||||
else:
|
||||
kv = apply_rotary_emb_kv_(
|
||||
kv, self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
|
||||
self.interleaved
|
||||
kv,
|
||||
self._cos_k_cached[seqlen_offset:],
|
||||
self._sin_k_cached[seqlen_offset:],
|
||||
self.interleaved,
|
||||
)
|
||||
return q, kv
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import xentropy_cuda_lib
|
||||
|
||||
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
|
||||
@@ -17,10 +16,16 @@ if "all_gather_into_tensor" not in dir(torch.distributed):
|
||||
|
||||
|
||||
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, logits, labels, smoothing=0.0, ignored_index=-100, inplace_backward=False,
|
||||
process_group=None):
|
||||
def forward(
|
||||
ctx,
|
||||
logits,
|
||||
labels,
|
||||
smoothing=0.0,
|
||||
ignored_index=-100,
|
||||
inplace_backward=False,
|
||||
process_group=None,
|
||||
):
|
||||
"""
|
||||
logits: (batch, vocab_size)
|
||||
labels: (batch,)
|
||||
@@ -34,7 +39,7 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
|
||||
|
||||
if world_size == 1:
|
||||
losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
|
||||
losses.masked_fill_(labels==ignored_index, 0)
|
||||
losses.masked_fill_(labels == ignored_index, 0)
|
||||
labels_local = labels
|
||||
else:
|
||||
rank = torch.distributed.get_rank(process_group)
|
||||
@@ -48,8 +53,9 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
|
||||
# For tensor parallel cross entropy with smoothing, we want to pass in the total number
|
||||
# of classes so that smoothing can be applied correctly. If total_classes=-1, use the
|
||||
# last dimension of the input tensor.
|
||||
losses, lse_local = xentropy_cuda_lib.forward(logits, labels_local, smoothing,
|
||||
world_size * vocab_size)
|
||||
losses, lse_local = xentropy_cuda_lib.forward(
|
||||
logits, labels_local, smoothing, world_size * vocab_size
|
||||
)
|
||||
assert lse_local.shape == (batch,)
|
||||
assert losses.shape == (batch,)
|
||||
losses.masked_fill_(ignored_mask, 0)
|
||||
@@ -61,10 +67,12 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
|
||||
# For labels not in the vocab of this partition, losses contains
|
||||
# 0.1 * (lse_local - sum logit / total_classes).
|
||||
|
||||
lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
|
||||
device=lse_local.device)
|
||||
torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
|
||||
group=process_group)
|
||||
lse_allgather = torch.empty(
|
||||
world_size, batch, dtype=lse_local.dtype, device=lse_local.device
|
||||
)
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
lse_allgather, lse_local.contiguous(), group=process_group
|
||||
)
|
||||
handle_losses = torch.distributed.all_reduce(
|
||||
losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
|
||||
)
|
||||
@@ -74,16 +82,18 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
|
||||
# If there's smoothing=0.1, the total losses are
|
||||
# 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
|
||||
# We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
|
||||
rank_per_sample = torch.div(labels, vocab_size, rounding_mode='floor')
|
||||
lse_local = lse_allgather[rank_per_sample,
|
||||
torch.arange(batch, device=lse_allgather.device)]
|
||||
rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor")
|
||||
lse_local = lse_allgather[
|
||||
rank_per_sample, torch.arange(batch, device=lse_allgather.device)
|
||||
]
|
||||
|
||||
handle_losses.wait()
|
||||
if smoothing == 0.0:
|
||||
losses += lse - lse_local
|
||||
else:
|
||||
losses += ((1 - smoothing) * (lse - lse_local)
|
||||
+ smoothing * (lse - lse_allgather.sum(dim=0)))
|
||||
losses += (1 - smoothing) * (lse - lse_local) + smoothing * (
|
||||
lse - lse_allgather.sum(dim=0)
|
||||
)
|
||||
losses.masked_fill_(ignored_mask, 0)
|
||||
|
||||
ctx.save_for_backward(logits, lse, labels_local)
|
||||
@@ -96,19 +106,24 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
|
||||
def backward(ctx, grad_loss):
|
||||
logits, lse, labels = ctx.saved_tensors
|
||||
grad_loss = grad_loss.contiguous()
|
||||
grad_loss.masked_fill_(labels==ctx.ignored_index, 0)
|
||||
grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels,
|
||||
ctx.smoothing, ctx.inplace_backward,
|
||||
ctx.total_classes)
|
||||
grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
|
||||
grad_logits = xentropy_cuda_lib.backward(
|
||||
grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes
|
||||
)
|
||||
return grad_logits, None, None, None, None, None, None
|
||||
|
||||
|
||||
class CrossEntropyLoss(nn.Module):
|
||||
|
||||
def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
|
||||
inplace_backward=False, process_group=None):
|
||||
def __init__(
|
||||
self,
|
||||
ignore_index=-100,
|
||||
reduction="mean",
|
||||
label_smoothing=0.0,
|
||||
inplace_backward=False,
|
||||
process_group=None,
|
||||
):
|
||||
super().__init__()
|
||||
if reduction not in ['mean', 'none']:
|
||||
if reduction not in ["mean", "none"]:
|
||||
raise NotImplementedError("Only support reduction = 'mean' or 'none'")
|
||||
self.ignore_index = ignore_index
|
||||
self.reduction = reduction
|
||||
@@ -120,10 +135,14 @@ class CrossEntropyLoss(nn.Module):
|
||||
assert input.is_cuda and target.is_cuda
|
||||
# SoftmaxCrossEntropyLoss implicitly casts to float
|
||||
loss = SoftmaxCrossEntropyLossFn.apply(
|
||||
input, target, self.label_smoothing, self.ignore_index, self.inplace_backward,
|
||||
self.process_group
|
||||
input,
|
||||
target,
|
||||
self.label_smoothing,
|
||||
self.ignore_index,
|
||||
self.inplace_backward,
|
||||
self.process_group,
|
||||
)
|
||||
if self.reduction == 'mean':
|
||||
if self.reduction == "mean":
|
||||
return loss.sum() / (target != self.ignore_index).sum()
|
||||
else:
|
||||
return loss
|
||||
|
||||
+242
-153
@@ -5,29 +5,32 @@
|
||||
|
||||
# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
|
||||
|
||||
import re
|
||||
import logging
|
||||
from functools import partial
|
||||
|
||||
from collections.abc import Sequence
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Sequence
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import BertConfig
|
||||
from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions
|
||||
from transformers.models.bert.modeling_bert import BertForPreTrainingOutput
|
||||
|
||||
from einops import rearrange
|
||||
from transformers import BertConfig
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
BertForPreTrainingOutput,
|
||||
)
|
||||
|
||||
from flash_attn.modules.mha import MHA
|
||||
from flash_attn.modules.mlp import Mlp, FusedMLP
|
||||
from flash_attn.bert_padding import (
|
||||
index_first_axis,
|
||||
index_first_axis_residual,
|
||||
pad_input,
|
||||
unpad_input,
|
||||
)
|
||||
from flash_attn.modules.block import Block
|
||||
from flash_attn.modules.embedding import BertEmbeddings
|
||||
from flash_attn.bert_padding import unpad_input, pad_input
|
||||
from flash_attn.bert_padding import index_first_axis, index_first_axis_residual
|
||||
from flash_attn.modules.mha import MHA
|
||||
from flash_attn.modules.mlp import FusedMLP, Mlp
|
||||
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
||||
|
||||
try:
|
||||
@@ -50,48 +53,63 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
||||
use_flash_attn = getattr(config, 'use_flash_attn', False)
|
||||
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
|
||||
use_flash_attn = getattr(config, "use_flash_attn", False)
|
||||
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
||||
rotary_kwargs = {}
|
||||
if config.position_embedding_type == "rotary":
|
||||
rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size)
|
||||
rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
|
||||
rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None)
|
||||
rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False)
|
||||
mixer_cls = partial(MHA, num_heads=config.num_attention_heads, cross_attn=cross_attn,
|
||||
dropout=config.attention_probs_dropout_prob, causal=False,
|
||||
fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn,
|
||||
return_residual=return_residual, **rotary_kwargs)
|
||||
mixer_cls = partial(
|
||||
MHA,
|
||||
num_heads=config.num_attention_heads,
|
||||
cross_attn=cross_attn,
|
||||
dropout=config.attention_probs_dropout_prob,
|
||||
causal=False,
|
||||
fused_bias_fc=fused_bias_fc,
|
||||
use_flash_attn=use_flash_attn,
|
||||
return_residual=return_residual,
|
||||
**rotary_kwargs,
|
||||
)
|
||||
return mixer_cls
|
||||
|
||||
|
||||
def create_mlp_cls(config, layer_idx=None, return_residual=False):
|
||||
inner_dim = config.intermediate_size
|
||||
fused_mlp = getattr(config, 'fused_mlp', False)
|
||||
fused_mlp = getattr(config, "fused_mlp", False)
|
||||
if fused_mlp:
|
||||
assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_mlp only '
|
||||
'supports approximate gelu')
|
||||
assert config.hidden_act in ["gelu_new", "gelu_fast"], (
|
||||
"fused_mlp only " "supports approximate gelu"
|
||||
)
|
||||
if not fused_mlp:
|
||||
approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
|
||||
mlp_cls = partial(Mlp, hidden_features=inner_dim,
|
||||
activation=partial(F.gelu, approximate=approximate),
|
||||
return_residual=return_residual)
|
||||
approximate = "tanh" if config.hidden_act in ["gelu_new", "gelu_fast"] else "none"
|
||||
mlp_cls = partial(
|
||||
Mlp,
|
||||
hidden_features=inner_dim,
|
||||
activation=partial(F.gelu, approximate=approximate),
|
||||
return_residual=return_residual,
|
||||
)
|
||||
else:
|
||||
if FusedMLP is None:
|
||||
raise ImportError('fused_dense is not installed')
|
||||
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
|
||||
raise ImportError("fused_dense is not installed")
|
||||
mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
|
||||
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
|
||||
if isinstance(mlp_checkpoint_lvl, Sequence):
|
||||
assert layer_idx is not None
|
||||
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
|
||||
mlp_cls = partial(FusedMLP, hidden_features=inner_dim,
|
||||
checkpoint_lvl=mlp_checkpoint_lvl, return_residual=return_residual)
|
||||
mlp_cls = partial(
|
||||
FusedMLP,
|
||||
hidden_features=inner_dim,
|
||||
checkpoint_lvl=mlp_checkpoint_lvl,
|
||||
return_residual=return_residual,
|
||||
)
|
||||
return mlp_cls
|
||||
|
||||
|
||||
def create_block(config, layer_idx=None):
|
||||
last_layer_subset = getattr(config, 'last_layer_subset', False)
|
||||
cross_attn=last_layer_subset and layer_idx == config.num_hidden_layers - 1
|
||||
last_layer_subset = getattr(config, "last_layer_subset", False)
|
||||
cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
|
||||
# TD [2022-12-19]: For cross attention (last layer), we actually want to return the
|
||||
# residual x_kv, not residual x. But it's annoying to change the API (and it only affects
|
||||
# one layer) so we just choose not to return residual in this case.
|
||||
@@ -99,11 +117,17 @@ def create_block(config, layer_idx=None):
|
||||
mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
|
||||
mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
|
||||
norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
|
||||
block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
|
||||
prenorm=False, resid_dropout1=config.hidden_dropout_prob,
|
||||
resid_dropout2=config.hidden_dropout_prob,
|
||||
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
|
||||
return_residual=return_residual)
|
||||
block = Block(
|
||||
config.hidden_size,
|
||||
mixer_cls,
|
||||
mlp_cls,
|
||||
norm_cls=norm_cls,
|
||||
prenorm=False,
|
||||
resid_dropout1=config.hidden_dropout_prob,
|
||||
resid_dropout2=config.hidden_dropout_prob,
|
||||
fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
|
||||
return_residual=return_residual,
|
||||
)
|
||||
return block
|
||||
|
||||
|
||||
@@ -120,12 +144,12 @@ def _init_weights(module, initializer_range=0.02):
|
||||
|
||||
|
||||
class BertEncoder(nn.Module):
|
||||
|
||||
def __init__(self, config: BertConfig):
|
||||
super().__init__()
|
||||
self.use_flash_attn = getattr(config, 'use_flash_attn', False)
|
||||
self.layers = nn.ModuleList([create_block(config, layer_idx=i)
|
||||
for i in range(config.num_hidden_layers)])
|
||||
self.use_flash_attn = getattr(config, "use_flash_attn", False)
|
||||
self.layers = nn.ModuleList(
|
||||
[create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
|
||||
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
||||
@@ -133,8 +157,9 @@ class BertEncoder(nn.Module):
|
||||
subset_mask: (batch, seqlen), dtype=torch.bool
|
||||
"""
|
||||
if key_padding_mask is None or not self.use_flash_attn:
|
||||
mixer_kwargs = ({'key_padding_mask': key_padding_mask}
|
||||
if key_padding_mask is not None else None)
|
||||
mixer_kwargs = (
|
||||
{"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
|
||||
)
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
||||
if subset_mask is not None:
|
||||
@@ -144,7 +169,7 @@ class BertEncoder(nn.Module):
|
||||
hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
|
||||
hidden_states, key_padding_mask
|
||||
)
|
||||
mixer_kwargs = {'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen_in_batch}
|
||||
mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
|
||||
if subset_mask is None:
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
||||
@@ -153,33 +178,40 @@ class BertEncoder(nn.Module):
|
||||
for layer in self.layers[:-1]:
|
||||
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
||||
if key_padding_mask is not None:
|
||||
subset_idx = torch.nonzero(subset_mask[key_padding_mask], as_tuple=False).flatten()
|
||||
subset_idx = torch.nonzero(
|
||||
subset_mask[key_padding_mask], as_tuple=False
|
||||
).flatten()
|
||||
subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
|
||||
subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0,
|
||||
dtype=torch.torch.int32), (1, 0))
|
||||
subset_cu_seqlens = F.pad(
|
||||
torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
|
||||
)
|
||||
else:
|
||||
subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
|
||||
subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
|
||||
subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0,
|
||||
dtype=torch.torch.int32), (1, 0))
|
||||
subset_cu_seqlens = F.pad(
|
||||
torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
|
||||
)
|
||||
hidden_states_subset, hidden_states = index_first_axis_residual(
|
||||
hidden_states, subset_idx
|
||||
)
|
||||
# It's ok to set max_seqlen_q to be much larger
|
||||
mixer_kwargs = {'x_kv': hidden_states,
|
||||
'cu_seqlens': subset_cu_seqlens, 'max_seqlen': max_seqlen_in_batch,
|
||||
'cu_seqlens_k': cu_seqlens, 'max_seqlen_k': max_seqlen_in_batch}
|
||||
mixer_kwargs = {
|
||||
"x_kv": hidden_states,
|
||||
"cu_seqlens": subset_cu_seqlens,
|
||||
"max_seqlen": max_seqlen_in_batch,
|
||||
"cu_seqlens_k": cu_seqlens,
|
||||
"max_seqlen_k": max_seqlen_in_batch,
|
||||
}
|
||||
hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertPooler(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
|
||||
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
||||
if fused_bias_fc and FusedDense is None:
|
||||
raise ImportError('fused_dense is not installed')
|
||||
raise ImportError("fused_dense is not installed")
|
||||
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
||||
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
@@ -194,18 +226,17 @@ class BertPooler(nn.Module):
|
||||
|
||||
|
||||
class BertPredictionHeadTransform(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
|
||||
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
||||
if fused_bias_fc and FusedDense is None:
|
||||
raise ImportError('fused_dense is not installed')
|
||||
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
|
||||
raise ImportError("fused_dense is not installed")
|
||||
self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
|
||||
if self.fused_dropout_add_ln and layer_norm is None:
|
||||
raise ImportError('dropout_add_layer_norm is not installed')
|
||||
raise ImportError("dropout_add_layer_norm is not installed")
|
||||
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
||||
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
||||
approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
|
||||
approximate = "tanh" if config.hidden_act in ["gelu_new", "gelu_fast"] else "none"
|
||||
self.transform_act_fn = nn.GELU(approximate=approximate)
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
@@ -215,18 +246,18 @@ class BertPredictionHeadTransform(nn.Module):
|
||||
if not self.fused_dropout_add_ln:
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
else:
|
||||
hidden_states = layer_norm(hidden_states, self.layer_norm.weight, self.layer_norm.bias,
|
||||
self.layer_norm.eps)
|
||||
hidden_states = layer_norm(
|
||||
hidden_states, self.layer_norm.weight, self.layer_norm.bias, self.layer_norm.eps
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertLMPredictionHead(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
|
||||
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
||||
if fused_bias_fc and FusedDense is None:
|
||||
raise ImportError('fused_dense is not installed')
|
||||
raise ImportError("fused_dense is not installed")
|
||||
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
||||
|
||||
self.transform = BertPredictionHeadTransform(config)
|
||||
@@ -254,9 +285,10 @@ class BertPreTrainingHeads(nn.Module):
|
||||
|
||||
|
||||
class BertPreTrainedModel(nn.Module):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
"""An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__()
|
||||
if not isinstance(config, BertConfig):
|
||||
@@ -265,7 +297,8 @@ class BertPreTrainedModel(nn.Module):
|
||||
"To create a model from a Google pretrained model use "
|
||||
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
|
||||
self.__class__.__name__, self.__class__.__name__
|
||||
))
|
||||
)
|
||||
)
|
||||
self.config = config
|
||||
|
||||
@classmethod
|
||||
@@ -287,28 +320,33 @@ class BertPreTrainedModel(nn.Module):
|
||||
"""
|
||||
# Instantiate model.
|
||||
model = cls(config, *inputs, **kwargs)
|
||||
load_return = model.load_state_dict(remap_state_dict(state_dict_from_pretrained(model_name),
|
||||
config), strict=False)
|
||||
load_return = model.load_state_dict(
|
||||
remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
|
||||
)
|
||||
logger.info(load_return)
|
||||
return model
|
||||
|
||||
|
||||
class BertModel(BertPreTrainedModel):
|
||||
|
||||
def __init__(self, config: BertConfig, add_pooling_layer=True):
|
||||
super().__init__(config)
|
||||
self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
||||
if config.vocab_size % self.pad_vocab_size_multiple != 0:
|
||||
config.vocab_size += (self.pad_vocab_size_multiple
|
||||
- (config.vocab_size % self.pad_vocab_size_multiple))
|
||||
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
|
||||
config.vocab_size += self.pad_vocab_size_multiple - (
|
||||
config.vocab_size % self.pad_vocab_size_multiple
|
||||
)
|
||||
self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
|
||||
if self.fused_dropout_add_ln and layer_norm is None:
|
||||
raise ImportError('dropout_add_layer_norm is not installed')
|
||||
assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']
|
||||
raise ImportError("dropout_add_layer_norm is not installed")
|
||||
assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast"]
|
||||
|
||||
self.embeddings = BertEmbeddings(config.hidden_size, config.vocab_size,
|
||||
config.max_position_embeddings, config.type_vocab_size,
|
||||
padding_idx=config.pad_token_id)
|
||||
self.embeddings = BertEmbeddings(
|
||||
config.hidden_size,
|
||||
config.vocab_size,
|
||||
config.max_position_embeddings,
|
||||
config.type_vocab_size,
|
||||
padding_idx=config.pad_token_id,
|
||||
)
|
||||
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.encoder = BertEncoder(config)
|
||||
@@ -316,36 +354,46 @@ class BertModel(BertPreTrainedModel):
|
||||
|
||||
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
|
||||
masked_tokens_mask=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
attention_mask=None,
|
||||
masked_tokens_mask=None,
|
||||
):
|
||||
"""If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
|
||||
we only want the output for the masked tokens. This means that we only compute the last
|
||||
layer output for these tokens.
|
||||
masked_tokens_mask: (batch, seqlen), dtype=torch.bool
|
||||
"""
|
||||
hidden_states = self.embeddings(input_ids, position_ids=position_ids,
|
||||
token_type_ids=token_type_ids)
|
||||
hidden_states = self.embeddings(
|
||||
input_ids, position_ids=position_ids, token_type_ids=token_type_ids
|
||||
)
|
||||
# TD [2022-12-18]: Don't need to force residual in fp32
|
||||
# BERT puts embedding LayerNorm before embedding dropout.
|
||||
if not self.fused_dropout_add_ln:
|
||||
hidden_states = self.emb_ln(hidden_states)
|
||||
else:
|
||||
hidden_states = layer_norm(hidden_states, self.emb_ln.weight, self.emb_ln.bias,
|
||||
self.emb_ln.eps)
|
||||
hidden_states = layer_norm(
|
||||
hidden_states, self.emb_ln.weight, self.emb_ln.bias, self.emb_ln.eps
|
||||
)
|
||||
hidden_states = self.emb_drop(hidden_states)
|
||||
|
||||
if masked_tokens_mask is not None:
|
||||
batch_size, seqlen = input_ids.shape[:2]
|
||||
# We also need the first column for the CLS token
|
||||
first_col_mask = torch.zeros(batch_size, seqlen, dtype=torch.bool,
|
||||
device=input_ids.device)
|
||||
first_col_mask = torch.zeros(
|
||||
batch_size, seqlen, dtype=torch.bool, device=input_ids.device
|
||||
)
|
||||
first_col_mask[:, 0] = True
|
||||
subset_mask = masked_tokens_mask | first_col_mask
|
||||
else:
|
||||
subset_mask = None
|
||||
|
||||
sequence_output = self.encoder(hidden_states, key_padding_mask=attention_mask,
|
||||
subset_mask=subset_mask)
|
||||
sequence_output = self.encoder(
|
||||
hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
|
||||
)
|
||||
|
||||
if masked_tokens_mask is None:
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
@@ -358,8 +406,7 @@ class BertModel(BertPreTrainedModel):
|
||||
else:
|
||||
pool_input = sequence_output[first_col_mask[subset_mask]]
|
||||
sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
|
||||
pooled_output = (self.pooler(pool_input, pool=False)
|
||||
if self.pooler is not None else None)
|
||||
pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
@@ -368,22 +415,24 @@ class BertModel(BertPreTrainedModel):
|
||||
|
||||
|
||||
class BertForPreTraining(BertPreTrainedModel):
|
||||
|
||||
def __init__(self, config: BertConfig):
|
||||
super().__init__(config)
|
||||
# If dense_seq_output, we only need to pass the hidden states for the masked out tokens
|
||||
# (around 15%) to the classifier heads.
|
||||
self.dense_seq_output = getattr(config, 'dense_seq_output', False)
|
||||
self.dense_seq_output = getattr(config, "dense_seq_output", False)
|
||||
# If last_layer_subset, we only need to compute the last layer for a subset of tokens
|
||||
# (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
|
||||
self.last_layer_subset = getattr(config, 'last_layer_subset', False)
|
||||
self.last_layer_subset = getattr(config, "last_layer_subset", False)
|
||||
if self.last_layer_subset:
|
||||
assert self.dense_seq_output, 'last_layer_subset requires dense_seq_output'
|
||||
use_xentropy = getattr(config, 'use_xentropy', False)
|
||||
assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
|
||||
use_xentropy = getattr(config, "use_xentropy", False)
|
||||
if use_xentropy and CrossEntropyLoss is None:
|
||||
raise ImportError('xentropy_cuda is not installed')
|
||||
loss_cls = (nn.CrossEntropyLoss if not use_xentropy
|
||||
else partial(CrossEntropyLoss, inplace_backward=True))
|
||||
raise ImportError("xentropy_cuda is not installed")
|
||||
loss_cls = (
|
||||
nn.CrossEntropyLoss
|
||||
if not use_xentropy
|
||||
else partial(CrossEntropyLoss, inplace_backward=True)
|
||||
)
|
||||
|
||||
self.bert = BertModel(config)
|
||||
self.cls = BertPreTrainingHeads(config)
|
||||
@@ -397,8 +446,15 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
def tie_weights(self):
|
||||
self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
|
||||
labels=None, next_sentence_label=None):
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
attention_mask=None,
|
||||
labels=None,
|
||||
next_sentence_label=None,
|
||||
):
|
||||
"""
|
||||
If labels are provided, they must be 0 for masked out tokens (as specified in the attention
|
||||
mask).
|
||||
@@ -414,28 +470,38 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
"""
|
||||
masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
|
||||
outputs = self.bert(
|
||||
input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
|
||||
input_ids,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
attention_mask=attention_mask.bool() if attention_mask is not None else None,
|
||||
masked_tokens_mask=masked_tokens_mask
|
||||
masked_tokens_mask=masked_tokens_mask,
|
||||
)
|
||||
sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
|
||||
if self.dense_seq_output and labels is not None:
|
||||
masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
|
||||
if not self.last_layer_subset:
|
||||
sequence_output = index_first_axis(rearrange(sequence_output, 'b s d -> (b s) d'),
|
||||
masked_token_idx)
|
||||
sequence_output = index_first_axis(
|
||||
rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
|
||||
)
|
||||
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
||||
|
||||
total_loss = None
|
||||
if labels is not None and next_sentence_label is not None:
|
||||
if self.dense_seq_output and labels is not None: # prediction_scores are already flattened
|
||||
masked_lm_loss = self.mlm_loss(prediction_scores,
|
||||
labels.flatten()[masked_token_idx])
|
||||
if (
|
||||
self.dense_seq_output and labels is not None
|
||||
): # prediction_scores are already flattened
|
||||
masked_lm_loss = self.mlm_loss(
|
||||
prediction_scores, labels.flatten()[masked_token_idx]
|
||||
)
|
||||
else:
|
||||
masked_lm_loss = self.mlm_loss(rearrange(prediction_scores, '... v -> (...) v'),
|
||||
rearrange(labels, '... -> (...)'))
|
||||
next_sentence_loss = self.nsp_loss(rearrange(seq_relationship_score, '... t -> (...) t'),
|
||||
rearrange(next_sentence_label, '... -> (...)'))
|
||||
masked_lm_loss = self.mlm_loss(
|
||||
rearrange(prediction_scores, "... v -> (...) v"),
|
||||
rearrange(labels, "... -> (...)"),
|
||||
)
|
||||
next_sentence_loss = self.nsp_loss(
|
||||
rearrange(seq_relationship_score, "... t -> (...) t"),
|
||||
rearrange(next_sentence_label, "... -> (...)"),
|
||||
)
|
||||
total_loss = masked_lm_loss.float() + next_sentence_loss.float()
|
||||
|
||||
return BertForPreTrainingOutput(
|
||||
@@ -448,83 +514,106 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
def remap_state_dict(state_dict, config):
|
||||
# LayerNorm
|
||||
def key_mapping_ln_gamma_beta(key):
|
||||
key = re.sub(r'LayerNorm.gamma$', 'LayerNorm.weight', key)
|
||||
key = re.sub(r'LayerNorm.beta$', 'LayerNorm.bias', key)
|
||||
key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
|
||||
key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
|
||||
|
||||
# Layers
|
||||
def key_mapping_layers(key):
|
||||
return re.sub(r'^bert.encoder.layer.', 'bert.encoder.layers.', key)
|
||||
return re.sub(r"^bert.encoder.layer.", "bert.encoder.layers.", key)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
|
||||
|
||||
# LayerNorm
|
||||
def key_mapping_ln(key):
|
||||
key = re.sub(r'^bert.embeddings.LayerNorm.', 'bert.emb_ln.', key)
|
||||
key = re.sub(r'^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)',
|
||||
r'bert.encoder.layers.\1.norm1.\2', key)
|
||||
key = re.sub(r'^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)',
|
||||
r'bert.encoder.layers.\1.norm2.\2', key)
|
||||
key = re.sub(r'^cls.predictions.transform.LayerNorm.(weight|bias)',
|
||||
r'cls.predictions.transform.layer_norm.\1', key)
|
||||
key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
|
||||
key = re.sub(
|
||||
r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
|
||||
r"bert.encoder.layers.\1.norm1.\2",
|
||||
key,
|
||||
)
|
||||
key = re.sub(
|
||||
r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
|
||||
r"bert.encoder.layers.\1.norm2.\2",
|
||||
key,
|
||||
)
|
||||
key = re.sub(
|
||||
r"^cls.predictions.transform.LayerNorm.(weight|bias)",
|
||||
r"cls.predictions.transform.layer_norm.\1",
|
||||
key,
|
||||
)
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
||||
|
||||
# MLP
|
||||
def key_mapping_mlp(key):
|
||||
key = re.sub(r'^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)',
|
||||
r'bert.encoder.layers.\1.mlp.fc1.\2', key)
|
||||
key = re.sub(r'^bert.encoder.layers.(\d+).output.dense.(weight|bias)',
|
||||
r'bert.encoder.layers.\1.mlp.fc2.\2', key)
|
||||
key = re.sub(
|
||||
r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
|
||||
r"bert.encoder.layers.\1.mlp.fc1.\2",
|
||||
key,
|
||||
)
|
||||
key = re.sub(
|
||||
r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
|
||||
r"bert.encoder.layers.\1.mlp.fc2.\2",
|
||||
key,
|
||||
)
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
||||
|
||||
# Attention
|
||||
last_layer_subset = getattr(config, 'last_layer_subset', False)
|
||||
last_layer_subset = getattr(config, "last_layer_subset", False)
|
||||
for d in range(config.num_hidden_layers):
|
||||
Wq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.weight')
|
||||
Wk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.weight')
|
||||
Wv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.weight')
|
||||
bq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.bias')
|
||||
bk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.bias')
|
||||
bv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.bias')
|
||||
Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
|
||||
Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
|
||||
Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
|
||||
bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
|
||||
bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
|
||||
bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
|
||||
if not (last_layer_subset and d == config.num_hidden_layers - 1):
|
||||
state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.weight'] = torch.cat(
|
||||
state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
|
||||
[Wq, Wk, Wv], dim=0
|
||||
)
|
||||
state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0)
|
||||
state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
|
||||
else:
|
||||
state_dict[f'bert.encoder.layers.{d}.mixer.Wq.weight'] = Wq
|
||||
state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.weight'] = torch.cat(
|
||||
[Wk, Wv], dim=0
|
||||
)
|
||||
state_dict[f'bert.encoder.layers.{d}.mixer.Wq.bias'] = bq
|
||||
state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.bias'] = torch.cat([bk, bv], dim=0)
|
||||
state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq
|
||||
state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
|
||||
state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq
|
||||
state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0)
|
||||
|
||||
def key_mapping_attn(key):
|
||||
return re.sub(r'^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)',
|
||||
r'bert.encoder.layers.\1.mixer.out_proj.\2', key)
|
||||
return re.sub(
|
||||
r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
|
||||
r"bert.encoder.layers.\1.mixer.out_proj.\2",
|
||||
key,
|
||||
)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
||||
|
||||
def key_mapping_decoder_bias(key):
|
||||
return re.sub(r'^cls.predictions.bias', 'cls.predictions.decoder.bias', key)
|
||||
return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
|
||||
|
||||
# Word embedding
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
||||
if pad_vocab_size_multiple > 1:
|
||||
word_embeddings = state_dict['bert.embeddings.word_embeddings.weight']
|
||||
state_dict['bert.embeddings.word_embeddings.weight'] = F.pad(
|
||||
word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
|
||||
state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
|
||||
word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
|
||||
)
|
||||
decoder_weight = state_dict['cls.predictions.decoder.weight']
|
||||
state_dict['cls.predictions.decoder.weight'] = F.pad(
|
||||
decoder_weight = state_dict["cls.predictions.decoder.weight"]
|
||||
state_dict["cls.predictions.decoder.weight"] = F.pad(
|
||||
decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
|
||||
)
|
||||
# If the vocab was padded, we want to set the decoder bias for those padded indices to be
|
||||
# strongly negative (i.e. the decoder shouldn't predict those indices).
|
||||
# TD [2022-05-09]: I don't think it affects the MLPerf training.
|
||||
decoder_bias = state_dict['cls.predictions.decoder.bias']
|
||||
state_dict['cls.predictions.decoder.bias'] = F.pad(
|
||||
decoder_bias = state_dict["cls.predictions.decoder.bias"]
|
||||
state_dict["cls.predictions.decoder.bias"] = F.pad(
|
||||
decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
|
||||
)
|
||||
|
||||
|
||||
+58
-37
@@ -2,93 +2,114 @@
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from transformers import GPT2Config, FalconConfig
|
||||
from transformers import FalconConfig, GPT2Config
|
||||
|
||||
|
||||
def remap_state_dict_hf_falcon(state_dict, config):
|
||||
def key_mapping_layers(key):
|
||||
return re.sub(r'^transformer.h.', 'transformer.layers.', key)
|
||||
return re.sub(r"^transformer.h.", "transformer.layers.", key)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
|
||||
# Word embedding
|
||||
def key_mapping_emb(key):
|
||||
return re.sub(r'^transformer.word_embeddings.', 'transformer.embeddings.word_embeddings.', key)
|
||||
return re.sub(
|
||||
r"^transformer.word_embeddings.", "transformer.embeddings.word_embeddings.", key
|
||||
)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
|
||||
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
|
||||
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
|
||||
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
|
||||
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
||||
vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
||||
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
|
||||
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
||||
)
|
||||
if getattr(config, 'tie_word_embeddings'):
|
||||
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
|
||||
if getattr(config, "tie_word_embeddings"):
|
||||
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
|
||||
else:
|
||||
output_embeddings = state_dict.pop('lm_head.weight')
|
||||
output_embeddings = state_dict.pop("lm_head.weight")
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
state_dict['lm_head.weight'] = F.pad(
|
||||
state_dict["lm_head.weight"] = F.pad(
|
||||
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
|
||||
)
|
||||
output_embeddings_bias = state_dict.pop('lm_head.bias')
|
||||
state_dict['lm_head.bias'] = F.pad(
|
||||
output_embeddings_bias = state_dict.pop("lm_head.bias")
|
||||
state_dict["lm_head.bias"] = F.pad(
|
||||
output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
|
||||
)
|
||||
|
||||
# LayerNorm
|
||||
def key_mapping_ln(key):
|
||||
key = re.sub(r'^transformer.layers.(\d+).input_layernorm.',
|
||||
r'transformer.layers.\1.norm1.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.',
|
||||
r'transformer.layers.\1.norm2.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).ln_attn.', r'transformer.layers.\1.norm1.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).ln_mlp.', r'transformer.layers.\1.norm2.', key)
|
||||
key = re.sub(
|
||||
r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
|
||||
)
|
||||
key = re.sub(
|
||||
r"^transformer.layers.(\d+).post_attention_layernorm.",
|
||||
r"transformer.layers.\1.norm2.",
|
||||
key,
|
||||
)
|
||||
key = re.sub(r"^transformer.layers.(\d+).ln_attn.", r"transformer.layers.\1.norm1.", key)
|
||||
key = re.sub(r"^transformer.layers.(\d+).ln_mlp.", r"transformer.layers.\1.norm2.", key)
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
||||
|
||||
# MLP
|
||||
def key_mapping_mlp(key):
|
||||
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.',
|
||||
r'transformer.layers.\1.mlp.fc1.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.',
|
||||
r'transformer.layers.\1.mlp.fc2.', key)
|
||||
key = re.sub(
|
||||
r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
|
||||
)
|
||||
key = re.sub(
|
||||
r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
|
||||
)
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
||||
|
||||
def key_mapping_attn(key):
|
||||
key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.',
|
||||
r'transformer.layers.\1.mixer.Wqkv.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.',
|
||||
r'transformer.layers.\1.mixer.out_proj.', key)
|
||||
key = re.sub(
|
||||
r"^transformer.layers.(\d+).self_attention.query_key_value.",
|
||||
r"transformer.layers.\1.mixer.Wqkv.",
|
||||
key,
|
||||
)
|
||||
key = re.sub(
|
||||
r"^transformer.layers.(\d+).self_attention.dense.",
|
||||
r"transformer.layers.\1.mixer.out_proj.",
|
||||
key,
|
||||
)
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
||||
n_head = config.n_head
|
||||
n_head_kv = getattr(config, "n_head_kv", 1)
|
||||
headdim = config.hidden_size // n_head
|
||||
for l in range(config.n_layer):
|
||||
# The weights are stored in a different layout compared to our implementation
|
||||
Wqkv = rearrange(state_dict.pop(f'transformer.layers.{l}.mixer.Wqkv.weight'),
|
||||
"(group ratio headdim) ... -> group ratio headdim ...",
|
||||
ratio=n_head // n_head_kv + 2, headdim=headdim)
|
||||
Wqkv = rearrange(
|
||||
state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight"),
|
||||
"(group ratio headdim) ... -> group ratio headdim ...",
|
||||
ratio=n_head // n_head_kv + 2,
|
||||
headdim=headdim,
|
||||
)
|
||||
Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...")
|
||||
Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...")
|
||||
Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...")
|
||||
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
|
||||
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config:
|
||||
# The 40b config uses "n_head_kv" instead of "num_kv_heads"
|
||||
n_head_kv = getattr(falcon_config, "n_head_kv",
|
||||
1 if getattr(falcon_config, "multi_query", False)
|
||||
else falcon_config.n_head)
|
||||
n_head_kv = getattr(
|
||||
falcon_config,
|
||||
"n_head_kv",
|
||||
1 if getattr(falcon_config, "multi_query", False) else falcon_config.n_head,
|
||||
)
|
||||
# HACK: the 40b config has 2 LN per layer instead of 1, but that's not reflected in the config.
|
||||
# So we have to infer it from the number of heads in the key/value block
|
||||
parallel_block_tied_norm = n_head_kv == 1
|
||||
|
||||
+20
-10
@@ -11,6 +11,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from transformers import GPT2Config
|
||||
|
||||
from flash_attn.models.falcon import remap_state_dict_hf_falcon
|
||||
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
|
||||
from flash_attn.models.gptj import remap_state_dict_hf_gptj
|
||||
@@ -27,10 +29,9 @@ from flash_attn.modules.mlp import (
|
||||
ParallelMLP,
|
||||
)
|
||||
from flash_attn.ops.activations import sqrelu_fwd
|
||||
from flash_attn.utils.distributed import all_gather_raw, sync_shared_params, get_dim_for_local_rank
|
||||
from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
|
||||
from flash_attn.utils.generation import GenerationMixin
|
||||
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
||||
from transformers import GPT2Config
|
||||
|
||||
try:
|
||||
from flash_attn.ops.fused_dense import ColumnParallelLinear
|
||||
@@ -690,7 +691,7 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
|
||||
if key in state_dict:
|
||||
x = state_dict[key]
|
||||
dim = x.shape[0] // world_size
|
||||
state_dict[key] = x[rank * dim: (rank + 1) * dim]
|
||||
state_dict[key] = x[rank * dim : (rank + 1) * dim]
|
||||
|
||||
def shard_last_dim(state_dict, key, multiple_of=1):
|
||||
if key in state_dict:
|
||||
@@ -707,17 +708,19 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
|
||||
x = state_dict[key]
|
||||
dim = x.shape[0] // world_size // 2
|
||||
state_dict[key] = rearrange(
|
||||
rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim: (rank + 1) * dim],
|
||||
rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
|
||||
"two o ... -> (two o) ...",
|
||||
)
|
||||
|
||||
def shard_qkv_headdim(state_dict, key):
|
||||
if key in state_dict:
|
||||
n_head_each_rank = [
|
||||
get_dim_for_local_rank(n_head, world_size, local_rank) for local_rank in range(world_size)
|
||||
get_dim_for_local_rank(n_head, world_size, local_rank)
|
||||
for local_rank in range(world_size)
|
||||
]
|
||||
n_head_kv_each_rank = [
|
||||
get_dim_for_local_rank(n_head_kv, world_size, local_rank) for local_rank in range(world_size)
|
||||
get_dim_for_local_rank(n_head_kv, world_size, local_rank)
|
||||
for local_rank in range(world_size)
|
||||
]
|
||||
|
||||
beg_n_head = sum(n_head_each_rank[:rank])
|
||||
@@ -729,7 +732,8 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
|
||||
if n_head_kv == n_head:
|
||||
x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
|
||||
state_dict[key] = rearrange(
|
||||
x[:, beg_n_head * head_dim : end_n_head * head_dim], "three d ... -> (three d) ..."
|
||||
x[:, beg_n_head * head_dim : end_n_head * head_dim],
|
||||
"three d ... -> (three d) ...",
|
||||
)
|
||||
else:
|
||||
x = rearrange(
|
||||
@@ -741,8 +745,14 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
|
||||
torch.cat(
|
||||
[
|
||||
x[beg_n_head:end_n_head],
|
||||
x[n_head + beg_n_head_kv: n_head + end_n_head_kv],
|
||||
x[n_head + n_head_kv + beg_n_head_kv: n_head + n_head_kv + end_n_head_kv],
|
||||
x[n_head + beg_n_head_kv : n_head + end_n_head_kv],
|
||||
x[
|
||||
n_head
|
||||
+ n_head_kv
|
||||
+ beg_n_head_kv : n_head
|
||||
+ n_head_kv
|
||||
+ end_n_head_kv
|
||||
],
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
@@ -824,7 +834,7 @@ def combine_state_dicts_tp(state_dicts, config):
|
||||
torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
|
||||
torch.cat(
|
||||
[
|
||||
x[n_head_per_rank: n_head_per_rank + n_head_kv_per_rank]
|
||||
x[n_head_per_rank : n_head_per_rank + n_head_kv_per_rank]
|
||||
for x in xs
|
||||
],
|
||||
dim=0,
|
||||
|
||||
@@ -2,80 +2,100 @@
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from transformers import GPT2Config, GPTNeoXConfig
|
||||
|
||||
|
||||
def remap_state_dict_hf_gpt_neox(state_dict, config):
|
||||
def key_mapping_layers(key):
|
||||
return re.sub(r'^gpt_neox.', 'transformer.', key)
|
||||
return re.sub(r"^gpt_neox.", "transformer.", key)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
|
||||
# Word embedding
|
||||
def key_mapping_emb(key):
|
||||
return re.sub(r'^transformer.embed_in.', 'transformer.embeddings.word_embeddings.', key)
|
||||
return re.sub(r"^transformer.embed_in.", "transformer.embeddings.word_embeddings.", key)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
|
||||
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
|
||||
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
|
||||
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
|
||||
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
||||
vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
||||
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
|
||||
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
||||
)
|
||||
if getattr(config, 'tie_word_embeddings'):
|
||||
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
|
||||
if getattr(config, "tie_word_embeddings"):
|
||||
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
|
||||
else:
|
||||
output_embeddings = state_dict.pop('embed_out.weight')
|
||||
output_embeddings = state_dict.pop("embed_out.weight")
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
state_dict['lm_head.weight'] = F.pad(
|
||||
state_dict["lm_head.weight"] = F.pad(
|
||||
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
|
||||
)
|
||||
|
||||
# LayerNorm
|
||||
def key_mapping_ln(key):
|
||||
key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).input_layernorm.', r'transformer.layers.\1.norm1.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.', r'transformer.layers.\1.norm2.', key)
|
||||
key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
|
||||
key = re.sub(
|
||||
r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
|
||||
)
|
||||
key = re.sub(
|
||||
r"^transformer.layers.(\d+).post_attention_layernorm.",
|
||||
r"transformer.layers.\1.norm2.",
|
||||
key,
|
||||
)
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
||||
|
||||
# MLP
|
||||
def key_mapping_mlp(key):
|
||||
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.', r'transformer.layers.\1.mlp.fc1.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.', r'transformer.layers.\1.mlp.fc2.', key)
|
||||
key = re.sub(
|
||||
r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
|
||||
)
|
||||
key = re.sub(
|
||||
r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
|
||||
)
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
||||
|
||||
# Attention
|
||||
for l in range(config.n_layer):
|
||||
# We don't store these biases
|
||||
state_dict.pop(f'transformer.layers.{l}.attention.bias')
|
||||
state_dict.pop(f'transformer.layers.{l}.attention.masked_bias')
|
||||
state_dict.pop(f"transformer.layers.{l}.attention.bias")
|
||||
state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
|
||||
# GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
|
||||
# while we store Wqkv as ((3 nheads headdim), hidden_dim)
|
||||
headdim = config.hidden_size // config.num_attention_heads
|
||||
Wqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.weight')
|
||||
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = rearrange(
|
||||
Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...',
|
||||
three=3, headdim=headdim
|
||||
Wqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.weight")
|
||||
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = rearrange(
|
||||
Wqkv,
|
||||
"(nheads three headdim) ... -> (three nheads headdim) ...",
|
||||
three=3,
|
||||
headdim=headdim,
|
||||
)
|
||||
bqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.bias')
|
||||
state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = rearrange(
|
||||
bqkv, '(nheads three headdim) -> (three nheads headdim)',
|
||||
three=3, headdim=headdim
|
||||
bqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.bias")
|
||||
state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = rearrange(
|
||||
bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
|
||||
)
|
||||
|
||||
def key_mapping_attn(key):
|
||||
key = re.sub(r'^transformer.layers.(\d+).attention.dense.',
|
||||
r'transformer.layers.\1.mixer.out_proj.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).attention.rotary_emb.',
|
||||
r'transformer.layers.\1.mixer.rotary_emb.', key)
|
||||
key = re.sub(
|
||||
r"^transformer.layers.(\d+).attention.dense.",
|
||||
r"transformer.layers.\1.mixer.out_proj.",
|
||||
key,
|
||||
)
|
||||
key = re.sub(
|
||||
r"^transformer.layers.(\d+).attention.rotary_emb.",
|
||||
r"transformer.layers.\1.mixer.rotary_emb.",
|
||||
key,
|
||||
)
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
||||
|
||||
return state_dict
|
||||
|
||||
+36
-25
@@ -2,67 +2,78 @@
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import GPT2Config, GPTJConfig
|
||||
|
||||
|
||||
def remap_state_dict_hf_gptj(state_dict, config):
|
||||
def key_mapping_layers(key):
|
||||
return re.sub(r'^transformer.h.', 'transformer.layers.', key)
|
||||
return re.sub(r"^transformer.h.", "transformer.layers.", key)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
|
||||
# Word embedding
|
||||
def key_mapping_emb(key):
|
||||
return re.sub(r'^transformer.wte.', 'transformer.embeddings.word_embeddings.', key)
|
||||
return re.sub(r"^transformer.wte.", "transformer.embeddings.word_embeddings.", key)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
|
||||
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
|
||||
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
|
||||
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
|
||||
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
||||
vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
||||
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
|
||||
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
||||
)
|
||||
if getattr(config, 'tie_word_embeddings'):
|
||||
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
|
||||
if getattr(config, "tie_word_embeddings"):
|
||||
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
|
||||
else:
|
||||
output_embeddings = state_dict.pop('lm_head.weight')
|
||||
output_embeddings = state_dict.pop("lm_head.weight")
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
state_dict['lm_head.weight'] = F.pad(
|
||||
state_dict["lm_head.weight"] = F.pad(
|
||||
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
|
||||
)
|
||||
output_embeddings_bias = state_dict.pop('lm_head.bias')
|
||||
state_dict['lm_head.bias'] = F.pad(
|
||||
output_embeddings_bias = state_dict.pop("lm_head.bias")
|
||||
state_dict["lm_head.bias"] = F.pad(
|
||||
output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
|
||||
)
|
||||
|
||||
# LayerNorm
|
||||
def key_mapping_ln(key):
|
||||
return re.sub(r'^transformer.layers.(\d+).ln_1.', r'transformer.layers.\1.norm1.', key)
|
||||
return re.sub(r"^transformer.layers.(\d+).ln_1.", r"transformer.layers.\1.norm1.", key)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
||||
|
||||
# MLP
|
||||
def key_mapping_mlp(key):
|
||||
key = re.sub(r'^transformer.layers.(\d+).mlp.fc_in.', r'transformer.layers.\1.mlp.fc1.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).mlp.fc_out.', r'transformer.layers.\1.mlp.fc2.', key)
|
||||
key = re.sub(
|
||||
r"^transformer.layers.(\d+).mlp.fc_in.", r"transformer.layers.\1.mlp.fc1.", key
|
||||
)
|
||||
key = re.sub(
|
||||
r"^transformer.layers.(\d+).mlp.fc_out.", r"transformer.layers.\1.mlp.fc2.", key
|
||||
)
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
||||
|
||||
# Attention
|
||||
for l in range(config.n_layer):
|
||||
Wq = state_dict.pop(f'transformer.layers.{l}.attn.q_proj.weight')
|
||||
Wk = state_dict.pop(f'transformer.layers.{l}.attn.k_proj.weight')
|
||||
Wv = state_dict.pop(f'transformer.layers.{l}.attn.v_proj.weight')
|
||||
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
|
||||
Wq = state_dict.pop(f"transformer.layers.{l}.attn.q_proj.weight")
|
||||
Wk = state_dict.pop(f"transformer.layers.{l}.attn.k_proj.weight")
|
||||
Wv = state_dict.pop(f"transformer.layers.{l}.attn.v_proj.weight")
|
||||
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
|
||||
# We don't store these biases
|
||||
state_dict.pop(f'transformer.layers.{l}.attn.bias')
|
||||
state_dict.pop(f'transformer.layers.{l}.attn.masked_bias')
|
||||
state_dict.pop(f"transformer.layers.{l}.attn.bias")
|
||||
state_dict.pop(f"transformer.layers.{l}.attn.masked_bias")
|
||||
|
||||
def key_mapping_attn(key):
|
||||
return re.sub(r'^transformer.layers.(\d+).attn.out_proj.',
|
||||
r'transformer.layers.\1.mixer.out_proj.', key)
|
||||
return re.sub(
|
||||
r"^transformer.layers.(\d+).attn.out_proj.",
|
||||
r"transformer.layers.\1.mixer.out_proj.",
|
||||
key,
|
||||
)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
||||
|
||||
return state_dict
|
||||
|
||||
+106
-70
@@ -15,63 +15,81 @@ from transformers import GPT2Config, LlamaConfig
|
||||
|
||||
def remap_state_dict_meta_llama(state_dict, config):
|
||||
def key_mapping_layers(key):
|
||||
return f'transformer.{key}' if not key.startswith('output.') else key
|
||||
return f"transformer.{key}" if not key.startswith("output.") else key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
|
||||
# Word embedding
|
||||
def key_mapping_emb(key):
|
||||
return re.sub(r'^transformer.tok_embeddings.', 'transformer.embeddings.word_embeddings.', key)
|
||||
return re.sub(
|
||||
r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
|
||||
)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
|
||||
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
|
||||
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
|
||||
* pad_vocab_size_multiple)
|
||||
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
|
||||
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
||||
vocab_size = (
|
||||
math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
||||
)
|
||||
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
|
||||
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
||||
)
|
||||
if getattr(config, 'tie_word_embeddings'):
|
||||
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
|
||||
if getattr(config, "tie_word_embeddings"):
|
||||
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
|
||||
else:
|
||||
output_embeddings = state_dict.pop('output.weight')
|
||||
output_embeddings = state_dict.pop("output.weight")
|
||||
# Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
|
||||
# differently.
|
||||
vocab_size = (math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
|
||||
* pad_vocab_size_multiple)
|
||||
vocab_size = (
|
||||
math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
|
||||
* pad_vocab_size_multiple
|
||||
)
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
state_dict['lm_head.weight'] = F.pad(
|
||||
state_dict["lm_head.weight"] = F.pad(
|
||||
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
|
||||
)
|
||||
|
||||
# LayerNorm
|
||||
def key_mapping_ln(key):
|
||||
key = re.sub(r'^transformer.norm.', r'transformer.ln_f.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).attention_norm.', r'transformer.layers.\1.norm1.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).ffn_norm.', r'transformer.layers.\1.norm2.', key)
|
||||
key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
|
||||
key = re.sub(
|
||||
r"^transformer.layers.(\d+).attention_norm.", r"transformer.layers.\1.norm1.", key
|
||||
)
|
||||
key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key)
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
||||
|
||||
# MLP
|
||||
for l in range(config.n_layer):
|
||||
w1 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w1.weight')
|
||||
w3 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w3.weight')
|
||||
w1 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w1.weight")
|
||||
w3 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w3.weight")
|
||||
# Our ordering is different
|
||||
state_dict[f'transformer.layers.{l}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0)
|
||||
state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)
|
||||
|
||||
def key_mapping_mlp(key):
|
||||
return re.sub(r'^transformer.layers.(\d+).feed_forward.w2.',
|
||||
r'transformer.layers.\1.mlp.fc2.', key)
|
||||
return re.sub(
|
||||
r"^transformer.layers.(\d+).feed_forward.w2.", r"transformer.layers.\1.mlp.fc2.", key
|
||||
)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
||||
|
||||
# Attention
|
||||
for l in range(config.n_layer):
|
||||
Wq = state_dict.pop(f'transformer.layers.{l}.attention.wq.weight')
|
||||
Wk = state_dict.pop(f'transformer.layers.{l}.attention.wk.weight')
|
||||
Wv = state_dict.pop(f'transformer.layers.{l}.attention.wv.weight')
|
||||
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
|
||||
Wq = state_dict.pop(f"transformer.layers.{l}.attention.wq.weight")
|
||||
Wk = state_dict.pop(f"transformer.layers.{l}.attention.wk.weight")
|
||||
Wv = state_dict.pop(f"transformer.layers.{l}.attention.wv.weight")
|
||||
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
|
||||
# We don't store these
|
||||
state_dict.pop(f'transformer.layers.{l}.attention.inner_attention.rope.freqs', None)
|
||||
state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)
|
||||
|
||||
def key_mapping_attn(key):
|
||||
return re.sub(r'^transformer.layers.(\d+).attention.wo.',
|
||||
r'transformer.layers.\1.mixer.out_proj.', key)
|
||||
return re.sub(
|
||||
r"^transformer.layers.(\d+).attention.wo.",
|
||||
r"transformer.layers.\1.mixer.out_proj.",
|
||||
key,
|
||||
)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
||||
|
||||
state_dict.pop("transformer.rope.freqs", None)
|
||||
@@ -82,29 +100,32 @@ def remap_state_dict_meta_llama(state_dict, config):
|
||||
def remap_state_dict_hf_llama(state_dict, config):
|
||||
# Embedding
|
||||
def key_mapping_emb(key):
|
||||
return re.sub(r'^model.embed_tokens.', 'transformer.embeddings.word_embeddings.', key)
|
||||
return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
|
||||
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
|
||||
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
|
||||
* pad_vocab_size_multiple)
|
||||
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
|
||||
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
||||
vocab_size = (
|
||||
math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
||||
)
|
||||
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
|
||||
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
||||
)
|
||||
|
||||
# LM head
|
||||
if getattr(config, 'tie_word_embeddings'):
|
||||
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
|
||||
if getattr(config, "tie_word_embeddings"):
|
||||
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
|
||||
else:
|
||||
output_embeddings = state_dict.pop('lm_head.weight')
|
||||
output_embeddings = state_dict.pop("lm_head.weight")
|
||||
# Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
|
||||
# differently.
|
||||
vocab_size = (math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
|
||||
* pad_vocab_size_multiple)
|
||||
vocab_size = (
|
||||
math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
|
||||
* pad_vocab_size_multiple
|
||||
)
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
state_dict['lm_head.weight'] = F.pad(
|
||||
state_dict["lm_head.weight"] = F.pad(
|
||||
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
|
||||
)
|
||||
|
||||
@@ -113,21 +134,22 @@ def remap_state_dict_hf_llama(state_dict, config):
|
||||
# Fusing weights this way based on difference in the following:
|
||||
# https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220
|
||||
# https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115
|
||||
w1 = state_dict.pop(f'model.layers.{l}.mlp.gate_proj.weight')
|
||||
w3 = state_dict.pop(f'model.layers.{l}.mlp.up_proj.weight')
|
||||
state_dict[f'transformer.layers.{l}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0)
|
||||
w1 = state_dict.pop(f"model.layers.{l}.mlp.gate_proj.weight")
|
||||
w3 = state_dict.pop(f"model.layers.{l}.mlp.up_proj.weight")
|
||||
state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)
|
||||
|
||||
def key_mapping_mlp(key):
|
||||
return re.sub(r'^model.layers.(\d+).mlp.down_proj.',
|
||||
r'transformer.layers.\1.mlp.fc2.', key)
|
||||
return re.sub(r"^model.layers.(\d+).mlp.down_proj.", r"transformer.layers.\1.mlp.fc2.", key)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
||||
|
||||
# LayerNorm
|
||||
def key_mapping_ln(key):
|
||||
key = re.sub(r'^model.norm.', r'transformer.ln_f.', key)
|
||||
key = re.sub(r'^model.layers.(\d+).input_layernorm.', r'transformer.layers.\1.norm1.', key)
|
||||
key = re.sub(r'^model.layers.(\d+).post_attention_layernorm.', r'transformer.layers.\1.norm2.', key)
|
||||
key = re.sub(r"^model.norm.", r"transformer.ln_f.", key)
|
||||
key = re.sub(r"^model.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key)
|
||||
key = re.sub(
|
||||
r"^model.layers.(\d+).post_attention_layernorm.", r"transformer.layers.\1.norm2.", key
|
||||
)
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
||||
@@ -135,42 +157,52 @@ def remap_state_dict_hf_llama(state_dict, config):
|
||||
def inv_permute(w):
|
||||
# Inverse of permute implemented in:
|
||||
# https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
|
||||
return w.reshape(
|
||||
config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd
|
||||
).transpose(1, 2).reshape(config.n_embd, config.n_embd)
|
||||
return (
|
||||
w.reshape(config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd)
|
||||
.transpose(1, 2)
|
||||
.reshape(config.n_embd, config.n_embd)
|
||||
)
|
||||
|
||||
# Attention
|
||||
for l in range(config.n_layer):
|
||||
Wq = state_dict.pop(f'model.layers.{l}.self_attn.q_proj.weight')
|
||||
Wk = state_dict.pop(f'model.layers.{l}.self_attn.k_proj.weight')
|
||||
Wv = state_dict.pop(f'model.layers.{l}.self_attn.v_proj.weight')
|
||||
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat(
|
||||
Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight")
|
||||
Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight")
|
||||
Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
|
||||
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
|
||||
[inv_permute(Wq), inv_permute(Wk), Wv], dim=0
|
||||
)
|
||||
# We don't store these
|
||||
state_dict.pop(f'model.layers.{l}.self_attn.rotary_emb.inv_freq', None)
|
||||
state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)
|
||||
|
||||
def key_mapping_attn(key):
|
||||
return re.sub(r'^model.layers.(\d+).self_attn.o_proj.',
|
||||
r'transformer.layers.\1.mixer.out_proj.', key)
|
||||
return re.sub(
|
||||
r"^model.layers.(\d+).self_attn.o_proj.", r"transformer.layers.\1.mixer.out_proj.", key
|
||||
)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
||||
return state_dict
|
||||
|
||||
|
||||
def config_from_meta_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> LlamaConfig:
|
||||
def config_from_meta_checkpoint(
|
||||
checkpoint_path: Union[str, os.PathLike], model_name: str
|
||||
) -> LlamaConfig:
|
||||
"""Load a LlamaConfig from a checkpoint path."""
|
||||
with open(Path(checkpoint_path) / model_name / 'params.json') as f:
|
||||
with open(Path(checkpoint_path) / model_name / "params.json") as f:
|
||||
params = json.load(f)
|
||||
config = LlamaConfig(hidden_size=params['dim'], intermediate_size=None,
|
||||
num_attention_heads=params['n_heads'],
|
||||
num_hidden_layers=params['n_layers'],
|
||||
rms_norm_eps=params['norm_eps'])
|
||||
config = LlamaConfig(
|
||||
hidden_size=params["dim"],
|
||||
intermediate_size=None,
|
||||
num_attention_heads=params["n_heads"],
|
||||
num_hidden_layers=params["n_layers"],
|
||||
rms_norm_eps=params["norm_eps"],
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def config_from_hf_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> LlamaConfig:
|
||||
return LlamaConfig.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf' / "config.json")
|
||||
def config_from_hf_checkpoint(
|
||||
checkpoint_path: Union[str, os.PathLike], model_name: str
|
||||
) -> LlamaConfig:
|
||||
return LlamaConfig.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf" / "config.json")
|
||||
|
||||
|
||||
def config_from_checkpoint(
|
||||
@@ -182,10 +214,14 @@ def config_from_checkpoint(
|
||||
return config_from_hf_checkpoint(checkpoint_path, model_name)
|
||||
|
||||
|
||||
def state_dicts_from_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> list[dict]:
|
||||
def state_dicts_from_checkpoint(
|
||||
checkpoint_path: Union[str, os.PathLike], model_name: str
|
||||
) -> list[dict]:
|
||||
# Need to sort, otherwise we mess up the ordering and the weights are wrong
|
||||
return [torch.load(path, map_location='cpu')
|
||||
for path in sorted((Path(checkpoint_path) / model_name).glob('consolidated.*.pth'))]
|
||||
return [
|
||||
torch.load(path, map_location="cpu")
|
||||
for path in sorted((Path(checkpoint_path) / model_name).glob("consolidated.*.pth"))
|
||||
]
|
||||
|
||||
|
||||
def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
|
||||
@@ -196,7 +232,7 @@ def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
|
||||
n_layer=llama_config.num_hidden_layers,
|
||||
n_head=llama_config.num_attention_heads,
|
||||
n_inner=llama_config.intermediate_size,
|
||||
activation_function='swiglu', # Hardcode since HF calls it 'silu'
|
||||
activation_function="swiglu", # Hardcode since HF calls it 'silu'
|
||||
# Llama doesn't have dropout, idk if it's because they only release the inference code
|
||||
resid_pdrop=0.0,
|
||||
embd_pdrop=0.0,
|
||||
|
||||
+51
-37
@@ -2,75 +2,86 @@
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import GPT2Config, OPTConfig
|
||||
|
||||
|
||||
def remap_state_dict_hf_opt(state_dict, config):
|
||||
def key_mapping_model(key):
|
||||
key = re.sub(r'^model.decoder.', 'transformer.', key)
|
||||
key = re.sub(r"^model.decoder.", "transformer.", key)
|
||||
# The OPT-350m model uses '^decoder' instead of '^model.decoder'
|
||||
key = re.sub(r'^decoder.', 'transformer.', key)
|
||||
key = re.sub(r"^decoder.", "transformer.", key)
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())
|
||||
# Word embedding and position embedding
|
||||
def key_mapping_emb(key):
|
||||
key = re.sub(r'^transformer.embed_tokens.', 'transformer.embeddings.word_embeddings.', key)
|
||||
key = re.sub(r"^transformer.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
|
||||
# The OPT-350m model uses has project_in and project_out
|
||||
key = re.sub(r'^transformer.project_in.', 'transformer.embeddings.project_in.', key)
|
||||
key = re.sub(r'^transformer.project_out.', 'project_out.', key)
|
||||
key = re.sub(r'^transformer.embed_positions.',
|
||||
'transformer.embeddings.position_embeddings.', key)
|
||||
key = re.sub(r"^transformer.project_in.", "transformer.embeddings.project_in.", key)
|
||||
key = re.sub(r"^transformer.project_out.", "project_out.", key)
|
||||
key = re.sub(
|
||||
r"^transformer.embed_positions.", "transformer.embeddings.position_embeddings.", key
|
||||
)
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
|
||||
# OPT uses the first 2 indices of pos_emb for padding tokens
|
||||
pos_embeddings = state_dict.pop('transformer.embeddings.position_embeddings.weight')
|
||||
state_dict['transformer.embeddings.position_embeddings.weight'] = pos_embeddings[2:]
|
||||
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
|
||||
pos_embeddings = state_dict.pop("transformer.embeddings.position_embeddings.weight")
|
||||
state_dict["transformer.embeddings.position_embeddings.weight"] = pos_embeddings[2:]
|
||||
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
|
||||
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
|
||||
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
||||
vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
||||
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
|
||||
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
||||
)
|
||||
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
|
||||
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
|
||||
|
||||
# LayerNorm
|
||||
def key_mapping_ln(key):
|
||||
key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
|
||||
key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
|
||||
# The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
|
||||
key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).self_attn_layer_norm.',
|
||||
r'transformer.layers.\1.norm1.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).final_layer_norm.',
|
||||
r'transformer.layers.\1.norm2.', key)
|
||||
key = re.sub(r"^transformer.layer_norm.", r"transformer.ln_f.", key)
|
||||
key = re.sub(
|
||||
r"^transformer.layers.(\d+).self_attn_layer_norm.", r"transformer.layers.\1.norm1.", key
|
||||
)
|
||||
key = re.sub(
|
||||
r"^transformer.layers.(\d+).final_layer_norm.", r"transformer.layers.\1.norm2.", key
|
||||
)
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
||||
|
||||
# MLP
|
||||
def key_mapping_mlp(key):
|
||||
return re.sub(r'^transformer.layers.(\d+).fc(1|2).',
|
||||
r'transformer.layers.\1.mlp.fc\2.', key)
|
||||
return re.sub(
|
||||
r"^transformer.layers.(\d+).fc(1|2).", r"transformer.layers.\1.mlp.fc\2.", key
|
||||
)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
||||
|
||||
# Attention
|
||||
for l in range(config.n_layer):
|
||||
Wq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.weight')
|
||||
Wk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.weight')
|
||||
Wv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.weight')
|
||||
bq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.bias')
|
||||
bk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.bias')
|
||||
bv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.bias')
|
||||
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
|
||||
state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0)
|
||||
Wq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.weight")
|
||||
Wk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.weight")
|
||||
Wv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.weight")
|
||||
bq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.bias")
|
||||
bk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.bias")
|
||||
bv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.bias")
|
||||
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
|
||||
state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
|
||||
|
||||
def key_mapping_attn(key):
|
||||
return re.sub(r'^transformer.layers.(\d+).self_attn.out_proj.',
|
||||
r'transformer.layers.\1.mixer.out_proj.', key)
|
||||
return re.sub(
|
||||
r"^transformer.layers.(\d+).self_attn.out_proj.",
|
||||
r"transformer.layers.\1.mixer.out_proj.",
|
||||
key,
|
||||
)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
||||
|
||||
return state_dict
|
||||
@@ -79,8 +90,11 @@ def remap_state_dict_hf_opt(state_dict, config):
|
||||
def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
|
||||
assert opt_config.layerdrop == 0.0
|
||||
assert opt_config.layer_norm_elementwise_affine
|
||||
word_embed_proj_dim = (None if opt_config.word_embed_proj_dim == opt_config.hidden_size
|
||||
else opt_config.word_embed_proj_dim)
|
||||
word_embed_proj_dim = (
|
||||
None
|
||||
if opt_config.word_embed_proj_dim == opt_config.hidden_size
|
||||
else opt_config.word_embed_proj_dim
|
||||
)
|
||||
return GPT2Config(
|
||||
vocab_size=opt_config.vocab_size,
|
||||
n_positions=opt_config.max_position_embeddings,
|
||||
@@ -98,5 +112,5 @@ def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
|
||||
eos_token_id=opt_config.eos_token_id,
|
||||
# These are new arguments not in the original GPT2Config
|
||||
prenorm=opt_config.do_layer_norm_before,
|
||||
word_embed_proj_dim=word_embed_proj_dim
|
||||
word_embed_proj_dim=word_embed_proj_dim,
|
||||
)
|
||||
|
||||
@@ -10,13 +10,14 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from timm.models.helpers import named_apply
|
||||
from torch.nn.init import trunc_normal_
|
||||
from torchvision.ops import StochasticDepth
|
||||
|
||||
from flash_attn.layers.patch_embed import PatchEmbed
|
||||
from flash_attn.modules.block import Block
|
||||
from flash_attn.modules.mha import MHA
|
||||
from flash_attn.modules.mlp import FusedMLP, Mlp
|
||||
from timm.models.helpers import named_apply
|
||||
from torch.nn.init import trunc_normal_
|
||||
from torchvision.ops import StochasticDepth
|
||||
|
||||
try:
|
||||
from flash_attn.ops.layer_norm import dropout_add_layer_norm
|
||||
|
||||
+168
-71
@@ -1,13 +1,12 @@
|
||||
# Copyright (c) 2022, Tri Dao.
|
||||
|
||||
from typing import Optional
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from torchvision.ops import StochasticDepth
|
||||
|
||||
from flash_attn.modules.mha import MHA
|
||||
@@ -35,11 +34,24 @@ except ImportError:
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
|
||||
def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
|
||||
dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0., resid_dropout2=0.,
|
||||
drop_path1=0., drop_path2=0., fused_dropout_add_ln=False, return_residual=False,
|
||||
residual_in_fp32=False, sequence_parallel=False, mark_shared_params=False):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
mixer_cls=None,
|
||||
mlp_cls=None,
|
||||
norm_cls=nn.LayerNorm,
|
||||
dropout_cls=nn.Dropout,
|
||||
prenorm=True,
|
||||
resid_dropout1=0.0,
|
||||
resid_dropout2=0.0,
|
||||
drop_path1=0.0,
|
||||
drop_path2=0.0,
|
||||
fused_dropout_add_ln=False,
|
||||
return_residual=False,
|
||||
residual_in_fp32=False,
|
||||
sequence_parallel=False,
|
||||
mark_shared_params=False,
|
||||
):
|
||||
"""
|
||||
For prenorm=True, this Block has a slightly different structure compared to a regular
|
||||
prenorm Transformer block.
|
||||
@@ -63,26 +75,27 @@ class Block(nn.Module):
|
||||
self.return_residual = return_residual
|
||||
self.residual_in_fp32 = residual_in_fp32
|
||||
if self.residual_in_fp32:
|
||||
assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True'
|
||||
assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
|
||||
if mixer_cls is None:
|
||||
mixer_cls = partial(MHA, num_heads=dim // 64)
|
||||
if mlp_cls is None:
|
||||
mlp_cls = partial(Mlp, hidden_features=4 * dim)
|
||||
self.mixer = mixer_cls(dim)
|
||||
self.dropout1 = dropout_cls(resid_dropout1)
|
||||
self.drop_path1 = StochasticDepth(drop_path1, mode='row')
|
||||
self.drop_path1 = StochasticDepth(drop_path1, mode="row")
|
||||
self.norm1 = norm_cls(dim)
|
||||
self.mlp = mlp_cls(dim)
|
||||
if not isinstance(self.mlp, nn.Identity):
|
||||
self.dropout2 = dropout_cls(resid_dropout2)
|
||||
self.drop_path2 = StochasticDepth(drop_path2, mode='row')
|
||||
self.drop_path2 = StochasticDepth(drop_path2, mode="row")
|
||||
self.norm2 = norm_cls(dim)
|
||||
|
||||
if self.fused_dropout_add_ln:
|
||||
assert dropout_add_layer_norm is not None, 'dropout_layer_norm is not installed'
|
||||
assert dropout_add_rms_norm is not None, 'dropout_layer_norm is not installed'
|
||||
assert (isinstance(self.norm1, (nn.LayerNorm, RMSNorm))
|
||||
and isinstance(self.dropout1, nn.Dropout))
|
||||
assert dropout_add_layer_norm is not None, "dropout_layer_norm is not installed"
|
||||
assert dropout_add_rms_norm is not None, "dropout_layer_norm is not installed"
|
||||
assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
|
||||
self.dropout1, nn.Dropout
|
||||
)
|
||||
|
||||
# TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
|
||||
# then the input to each worker in the tensor parallel group will be different.
|
||||
@@ -94,22 +107,27 @@ class Block(nn.Module):
|
||||
if sequence_parallel:
|
||||
for p in self.norm1.parameters():
|
||||
p._sequence_parallel = True
|
||||
if hasattr(self, 'norm2'):
|
||||
if hasattr(self, "norm2"):
|
||||
for p in self.norm2.parameters():
|
||||
p._sequence_parallel = True
|
||||
# Mark the norm parameters as "shared_params" so that we sync their values at init.
|
||||
if mark_shared_params:
|
||||
for p in self.norm1.parameters():
|
||||
p._shared_params = True
|
||||
if hasattr(self, 'norm2'):
|
||||
if hasattr(self, "norm2"):
|
||||
for p in self.norm2.parameters():
|
||||
p._shared_params = True
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
||||
|
||||
def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
|
||||
mixer_subset=None, mixer_kwargs=None):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Tensor,
|
||||
residual: Optional[Tensor] = None,
|
||||
mixer_subset=None,
|
||||
mixer_kwargs=None,
|
||||
):
|
||||
r"""Pass the input through the encoder layer.
|
||||
|
||||
Args:
|
||||
@@ -119,8 +137,11 @@ class Block(nn.Module):
|
||||
before applying the query projection. Useful for e.g., ViT where we only care
|
||||
about the CLS token in the last layer.
|
||||
"""
|
||||
fused_add_norm_fn = (dropout_add_rms_norm if RMSNorm and isinstance(self.norm1, RMSNorm)
|
||||
else dropout_add_layer_norm)
|
||||
fused_add_norm_fn = (
|
||||
dropout_add_rms_norm
|
||||
if RMSNorm and isinstance(self.norm1, RMSNorm)
|
||||
else dropout_add_layer_norm
|
||||
)
|
||||
if self.prenorm:
|
||||
if not self.fused_dropout_add_ln:
|
||||
dropped = self.drop_path1(self.dropout1(hidden_states))
|
||||
@@ -132,19 +153,28 @@ class Block(nn.Module):
|
||||
if self.drop_path1.p == 0 or not self.training:
|
||||
rowscale1 = None
|
||||
else:
|
||||
rowscale1 = self.drop_path1(torch.ones(
|
||||
hidden_states.shape[:-1], device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
rowscale1 = self.drop_path1(
|
||||
torch.ones(
|
||||
hidden_states.shape[:-1],
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
)
|
||||
hidden_states, residual = fused_add_norm_fn(
|
||||
hidden_states, residual, self.norm1.weight, self.norm1.bias,
|
||||
self.dropout1.p if self.training else 0.0, self.norm1.eps,
|
||||
rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32
|
||||
hidden_states,
|
||||
residual,
|
||||
self.norm1.weight,
|
||||
self.norm1.bias,
|
||||
self.dropout1.p if self.training else 0.0,
|
||||
self.norm1.eps,
|
||||
rowscale=rowscale1,
|
||||
prenorm=True,
|
||||
residual_in_fp32=self.residual_in_fp32,
|
||||
)
|
||||
if mixer_kwargs is None:
|
||||
mixer_kwargs = {}
|
||||
if mixer_subset is not None:
|
||||
mixer_kwargs['mixer_subset'] = mixer_subset
|
||||
mixer_kwargs["mixer_subset"] = mixer_subset
|
||||
hidden_states = self.mixer(hidden_states, **mixer_kwargs)
|
||||
if mixer_subset is not None:
|
||||
residual = residual[:, mixer_subset]
|
||||
@@ -159,14 +189,23 @@ class Block(nn.Module):
|
||||
if self.drop_path2.p == 0 or not self.training:
|
||||
rowscale2 = None
|
||||
else:
|
||||
rowscale2 = self.drop_path2(torch.ones(
|
||||
hidden_states.shape[:-1], device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
rowscale2 = self.drop_path2(
|
||||
torch.ones(
|
||||
hidden_states.shape[:-1],
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
)
|
||||
hidden_states, residual = fused_add_norm_fn(
|
||||
hidden_states, residual, self.norm2.weight, self.norm2.bias,
|
||||
self.dropout2.p if self.training else 0.0, self.norm2.eps,
|
||||
rowscale=rowscale2, prenorm=True, residual_in_fp32=self.residual_in_fp32
|
||||
hidden_states,
|
||||
residual,
|
||||
self.norm2.weight,
|
||||
self.norm2.bias,
|
||||
self.dropout2.p if self.training else 0.0,
|
||||
self.norm2.eps,
|
||||
rowscale=rowscale2,
|
||||
prenorm=True,
|
||||
residual_in_fp32=self.residual_in_fp32,
|
||||
)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
@@ -178,38 +217,58 @@ class Block(nn.Module):
|
||||
if self.return_residual: # mixer out is actually a pair here
|
||||
mixer_out, hidden_states = mixer_out
|
||||
if not self.fused_dropout_add_ln:
|
||||
hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out))
|
||||
+ hidden_states).to(dtype=self.norm1.weight.dtype))
|
||||
hidden_states = self.norm1(
|
||||
(self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
|
||||
dtype=self.norm1.weight.dtype
|
||||
)
|
||||
)
|
||||
else:
|
||||
if self.drop_path1.p == 0 or not self.training:
|
||||
rowscale1 = None
|
||||
else:
|
||||
rowscale1 = self.drop_path1(torch.ones(
|
||||
mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype)
|
||||
rowscale1 = self.drop_path1(
|
||||
torch.ones(
|
||||
mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
|
||||
)
|
||||
)
|
||||
hidden_states = fused_add_norm_fn(
|
||||
mixer_out, hidden_states, self.norm1.weight, self.norm1.bias,
|
||||
self.dropout1.p if self.training else 0.0, self.norm1.eps,
|
||||
rowscale=rowscale1, prenorm=False
|
||||
mixer_out,
|
||||
hidden_states,
|
||||
self.norm1.weight,
|
||||
self.norm1.bias,
|
||||
self.dropout1.p if self.training else 0.0,
|
||||
self.norm1.eps,
|
||||
rowscale=rowscale1,
|
||||
prenorm=False,
|
||||
)
|
||||
if not isinstance(self.mlp, nn.Identity):
|
||||
mlp_out = self.mlp(hidden_states)
|
||||
if self.return_residual: # mlp out is actually a pair here
|
||||
mlp_out, hidden_states = mlp_out
|
||||
if not self.fused_dropout_add_ln:
|
||||
hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out))
|
||||
+ hidden_states).to(dtype=self.norm2.weight.dtype))
|
||||
hidden_states = self.norm2(
|
||||
(self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
|
||||
dtype=self.norm2.weight.dtype
|
||||
)
|
||||
)
|
||||
else:
|
||||
if self.drop_path2.p == 0 or not self.training:
|
||||
rowscale2 = None
|
||||
else:
|
||||
rowscale2 = self.drop_path2(torch.ones(
|
||||
mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype)
|
||||
rowscale2 = self.drop_path2(
|
||||
torch.ones(
|
||||
mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
|
||||
)
|
||||
)
|
||||
hidden_states = fused_add_norm_fn(
|
||||
mlp_out, hidden_states, self.norm2.weight, self.norm2.bias,
|
||||
self.dropout2.p if self.training else 0.0, self.norm2.eps,
|
||||
rowscale=rowscale2, prenorm=False
|
||||
mlp_out,
|
||||
hidden_states,
|
||||
self.norm2.weight,
|
||||
self.norm2.bias,
|
||||
self.dropout2.p if self.training else 0.0,
|
||||
self.norm2.eps,
|
||||
rowscale=rowscale2,
|
||||
prenorm=False,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
@@ -219,10 +278,21 @@ class ParallelBlock(nn.Module):
|
||||
and PaLM.
|
||||
"""
|
||||
|
||||
def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
|
||||
dropout_cls=nn.Dropout, resid_dropout1=0., resid_dropout2=0.,
|
||||
tied_norm=False, fused_dropout_add_ln=False, residual_in_fp32=False,
|
||||
sequence_parallel=False, mark_shared_params=False):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
mixer_cls=None,
|
||||
mlp_cls=None,
|
||||
norm_cls=nn.LayerNorm,
|
||||
dropout_cls=nn.Dropout,
|
||||
resid_dropout1=0.0,
|
||||
resid_dropout2=0.0,
|
||||
tied_norm=False,
|
||||
fused_dropout_add_ln=False,
|
||||
residual_in_fp32=False,
|
||||
sequence_parallel=False,
|
||||
mark_shared_params=False,
|
||||
):
|
||||
"""
|
||||
This Block has a slightly different structure compared to a regular
|
||||
prenorm Transformer block.
|
||||
@@ -250,10 +320,15 @@ class ParallelBlock(nn.Module):
|
||||
self.norm2 = norm_cls(dim)
|
||||
|
||||
if self.fused_dropout_add_ln:
|
||||
assert dropout_add_layer_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
|
||||
assert dropout_add_rms_norm_parallel_residual is not None, 'dropout_layer_norm is not installed'
|
||||
assert (isinstance(self.norm1, (nn.LayerNorm, RMSNorm))
|
||||
and isinstance(self.dropout1, nn.Dropout))
|
||||
assert (
|
||||
dropout_add_layer_norm_parallel_residual is not None
|
||||
), "dropout_layer_norm is not installed"
|
||||
assert (
|
||||
dropout_add_rms_norm_parallel_residual is not None
|
||||
), "dropout_layer_norm is not installed"
|
||||
assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
|
||||
self.dropout1, nn.Dropout
|
||||
)
|
||||
|
||||
# TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
|
||||
# then the input to each worker in the tensor parallel group will be different.
|
||||
@@ -265,22 +340,27 @@ class ParallelBlock(nn.Module):
|
||||
if sequence_parallel:
|
||||
for p in self.norm1.parameters():
|
||||
p._sequence_parallel = True
|
||||
if hasattr(self, 'norm2'):
|
||||
if hasattr(self, "norm2"):
|
||||
for p in self.norm2.parameters():
|
||||
p._sequence_parallel = True
|
||||
# Mark the norm parameters as "shared_params" so that we sync their values at init.
|
||||
if mark_shared_params:
|
||||
for p in self.norm1.parameters():
|
||||
p._shared_params = True
|
||||
if hasattr(self, 'norm2'):
|
||||
if hasattr(self, "norm2"):
|
||||
for p in self.norm2.parameters():
|
||||
p._shared_params = True
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
||||
|
||||
def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None,
|
||||
residual: Optional[Tensor] = None, mixer_kwargs=None):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states1: Tensor,
|
||||
hidden_states2: Optional[Tensor] = None,
|
||||
residual: Optional[Tensor] = None,
|
||||
mixer_kwargs=None,
|
||||
):
|
||||
r"""Pass the input through the encoder layer.
|
||||
|
||||
Args:
|
||||
@@ -290,30 +370,47 @@ class ParallelBlock(nn.Module):
|
||||
"""
|
||||
# TODO: Ideally we should only do the allgather / allreduce once for
|
||||
# the Linear to MLP & Attention
|
||||
fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
|
||||
if isinstance(self.norm1, RMSNorm)
|
||||
else dropout_add_layer_norm_parallel_residual)
|
||||
fused_add_norm_fn = (
|
||||
dropout_add_rms_norm_parallel_residual
|
||||
if isinstance(self.norm1, RMSNorm)
|
||||
else dropout_add_layer_norm_parallel_residual
|
||||
)
|
||||
if not self.fused_dropout_add_ln:
|
||||
dropped1 = self.dropout1(hidden_states1)
|
||||
# For the very 1st block, we only want 1 dropout, not two different dropouts
|
||||
if hidden_states2 is not None:
|
||||
dropped2 = self.dropout2(hidden_states2)
|
||||
residual = ((residual + dropped1 + dropped2)
|
||||
if residual is not None else dropped1 + dropped2)
|
||||
residual = (
|
||||
(residual + dropped1 + dropped2)
|
||||
if residual is not None
|
||||
else dropped1 + dropped2
|
||||
)
|
||||
else:
|
||||
residual = (residual + dropped1) if residual is not None else dropped1
|
||||
hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
|
||||
hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype))
|
||||
if not self.tied_norm else hidden_states1)
|
||||
hidden_states2 = (
|
||||
self.norm2(residual.to(dtype=self.norm2.weight.dtype))
|
||||
if not self.tied_norm
|
||||
else hidden_states1
|
||||
)
|
||||
if self.residual_in_fp32:
|
||||
residual = residual.to(torch.float32)
|
||||
else:
|
||||
weight2, bias2 = ((self.norm2.weight, self.norm2.bias)
|
||||
if not self.tied_norm else (None, None))
|
||||
weight2, bias2 = (
|
||||
(self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
|
||||
)
|
||||
hidden_states1, hidden_states2, residual = fused_add_norm_fn(
|
||||
hidden_states1, hidden_states2, residual, self.norm1.weight, self.norm1.bias,
|
||||
weight2, bias2, self.dropout1.p if self.training else 0.0, self.norm1.eps,
|
||||
prenorm=True, residual_in_fp32=self.residual_in_fp32
|
||||
hidden_states1,
|
||||
hidden_states2,
|
||||
residual,
|
||||
self.norm1.weight,
|
||||
self.norm1.bias,
|
||||
weight2,
|
||||
bias2,
|
||||
self.dropout1.p if self.training else 0.0,
|
||||
self.norm1.eps,
|
||||
prenorm=True,
|
||||
residual_in_fp32=self.residual_in_fp32,
|
||||
)
|
||||
if self.tied_norm:
|
||||
hidden_states2 = hidden_states1
|
||||
|
||||
@@ -2,42 +2,52 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from flash_attn.utils.distributed import reduce_scatter, all_reduce
|
||||
from flash_attn.utils.distributed import all_reduce, reduce_scatter
|
||||
|
||||
|
||||
class GPT2Embeddings(nn.Module):
|
||||
|
||||
def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
|
||||
word_embed_proj_dim=None, device=None, dtype=None):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
vocab_size,
|
||||
max_position_embeddings,
|
||||
padding_idx=None,
|
||||
word_embed_proj_dim=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
If max_position_embeddings <= 0, there's no position embeddings
|
||||
If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
|
||||
the project up to embed_dim
|
||||
If max_position_embeddings <= 0, there's no position embeddings
|
||||
If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
|
||||
the project up to embed_dim
|
||||
"""
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
if word_embed_proj_dim is None:
|
||||
self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
|
||||
**factory_kwargs)
|
||||
self.word_embeddings = nn.Embedding(
|
||||
vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
|
||||
)
|
||||
self.project_in = None
|
||||
else:
|
||||
self.word_embeddings = nn.Embedding(vocab_size, word_embed_proj_dim,
|
||||
padding_idx=padding_idx, **factory_kwargs)
|
||||
self.project_in = nn.Linear(word_embed_proj_dim, embed_dim, bias=False,
|
||||
**factory_kwargs)
|
||||
self.word_embeddings = nn.Embedding(
|
||||
vocab_size, word_embed_proj_dim, padding_idx=padding_idx, **factory_kwargs
|
||||
)
|
||||
self.project_in = nn.Linear(
|
||||
word_embed_proj_dim, embed_dim, bias=False, **factory_kwargs
|
||||
)
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
if self.max_position_embeddings > 0:
|
||||
self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
|
||||
**factory_kwargs)
|
||||
self.position_embeddings = nn.Embedding(
|
||||
max_position_embeddings, embed_dim, **factory_kwargs
|
||||
)
|
||||
|
||||
def forward(self, input_ids, position_ids=None):
|
||||
"""
|
||||
input_ids: (batch, seqlen)
|
||||
position_ids: (batch, seqlen)
|
||||
input_ids: (batch, seqlen)
|
||||
position_ids: (batch, seqlen)
|
||||
"""
|
||||
batch_size, seqlen = input_ids.shape
|
||||
embeddings = self.word_embeddings(input_ids)
|
||||
@@ -52,31 +62,39 @@ class GPT2Embeddings(nn.Module):
|
||||
|
||||
|
||||
class BertEmbeddings(nn.Module):
|
||||
|
||||
def __init__(self, embed_dim, vocab_size, max_position_embeddings, type_vocab_size,
|
||||
padding_idx=None, device=None, dtype=None):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
vocab_size,
|
||||
max_position_embeddings,
|
||||
type_vocab_size,
|
||||
padding_idx=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
If max_position_embeddings <= 0, there's no position embeddings
|
||||
If type_vocab_size <= 0, there's no token type embeddings
|
||||
If max_position_embeddings <= 0, there's no position embeddings
|
||||
If type_vocab_size <= 0, there's no token type embeddings
|
||||
"""
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
|
||||
**factory_kwargs)
|
||||
self.word_embeddings = nn.Embedding(
|
||||
vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
|
||||
)
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
if self.max_position_embeddings > 0:
|
||||
self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
|
||||
**factory_kwargs)
|
||||
self.position_embeddings = nn.Embedding(
|
||||
max_position_embeddings, embed_dim, **factory_kwargs
|
||||
)
|
||||
if self.type_vocab_size > 0:
|
||||
self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim,
|
||||
**factory_kwargs)
|
||||
self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None):
|
||||
"""
|
||||
input_ids: (batch, seqlen)
|
||||
position_ids: (batch, seqlen)
|
||||
token_type_ids: (batch, seqlen)
|
||||
input_ids: (batch, seqlen)
|
||||
position_ids: (batch, seqlen)
|
||||
token_type_ids: (batch, seqlen)
|
||||
"""
|
||||
batch_size, seqlen = input_ids.shape
|
||||
embeddings = self.word_embeddings(input_ids)
|
||||
@@ -94,16 +112,17 @@ class BertEmbeddings(nn.Module):
|
||||
|
||||
|
||||
class VocabParallelEmbedding(nn.Embedding):
|
||||
|
||||
def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
|
||||
self.process_group = process_group
|
||||
if process_group is not None:
|
||||
world_size = torch.distributed.get_world_size(process_group)
|
||||
if num_embeddings % world_size != 0:
|
||||
raise ValueError(f'num_embeddings ({num_embeddings}) must be divisible by '
|
||||
f'world_size ({world_size})')
|
||||
raise ValueError(
|
||||
f"num_embeddings ({num_embeddings}) must be divisible by "
|
||||
f"world_size ({world_size})"
|
||||
)
|
||||
if world_size > 1 and padding_idx is not None:
|
||||
raise RuntimeError('ParallelEmbedding does not support padding_idx')
|
||||
raise RuntimeError("ParallelEmbedding does not support padding_idx")
|
||||
else:
|
||||
world_size = 1
|
||||
super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
|
||||
@@ -125,33 +144,45 @@ class VocabParallelEmbedding(nn.Embedding):
|
||||
|
||||
|
||||
class ColumnParallelEmbedding(nn.Embedding):
|
||||
|
||||
def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
|
||||
self.process_group = process_group
|
||||
if process_group is not None:
|
||||
world_size = torch.distributed.get_world_size(process_group)
|
||||
if embedding_dim % world_size != 0:
|
||||
raise ValueError(f'embedding_dim ({embedding_dim}) must be divisible by '
|
||||
f'world_size ({world_size})')
|
||||
raise ValueError(
|
||||
f"embedding_dim ({embedding_dim}) must be divisible by "
|
||||
f"world_size ({world_size})"
|
||||
)
|
||||
else:
|
||||
world_size = 1
|
||||
super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
|
||||
|
||||
|
||||
class ParallelGPT2Embeddings(nn.Module):
|
||||
|
||||
def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
|
||||
padding_idx=None, sequence_parallel=True, device=None, dtype=None):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
vocab_size,
|
||||
max_position_embeddings,
|
||||
process_group,
|
||||
padding_idx=None,
|
||||
sequence_parallel=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
If max_position_embeddings <= 0, there's no position embeddings
|
||||
If max_position_embeddings <= 0, there's no position embeddings
|
||||
"""
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.process_group = process_group
|
||||
self.sequence_parallel = sequence_parallel
|
||||
self.word_embeddings = VocabParallelEmbedding(
|
||||
vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group,
|
||||
**factory_kwargs
|
||||
vocab_size,
|
||||
embed_dim,
|
||||
padding_idx=padding_idx,
|
||||
process_group=process_group,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
if self.max_position_embeddings > 0:
|
||||
@@ -161,8 +192,8 @@ class ParallelGPT2Embeddings(nn.Module):
|
||||
|
||||
def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
|
||||
"""
|
||||
input_ids: (batch, seqlen)
|
||||
position_ids: (batch, seqlen)
|
||||
input_ids: (batch, seqlen)
|
||||
position_ids: (batch, seqlen)
|
||||
"""
|
||||
batch_size, seqlen = input_ids.shape
|
||||
world_size = torch.distributed.get_world_size(self.process_group)
|
||||
@@ -176,8 +207,10 @@ class ParallelGPT2Embeddings(nn.Module):
|
||||
else:
|
||||
partition_dim = self.position_embeddings.embedding_dim
|
||||
rank = torch.distributed.get_rank(self.process_group)
|
||||
embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings
|
||||
embeddings[
|
||||
..., rank * partition_dim : (rank + 1) * partition_dim
|
||||
] += position_embeddings
|
||||
if combine_batch_seqlen_dim:
|
||||
embeddings = rearrange(embeddings, 'b s d -> (b s) d')
|
||||
embeddings = rearrange(embeddings, "b s d -> (b s) d")
|
||||
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
|
||||
return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
|
||||
|
||||
@@ -732,8 +732,12 @@ class ParallelMHA(nn.Module):
|
||||
self.num_heads % self.num_heads_kv == 0
|
||||
), "num_heads must be divisible by num_heads_kv"
|
||||
|
||||
self.num_heads_per_rank = get_dim_for_local_rank(self.num_heads, self.world_size, self.local_rank)
|
||||
self.num_heads_kv_per_rank = get_dim_for_local_rank(self.num_heads, self.world_size, self.local_rank)
|
||||
self.num_heads_per_rank = get_dim_for_local_rank(
|
||||
self.num_heads, self.world_size, self.local_rank
|
||||
)
|
||||
self.num_heads_kv_per_rank = get_dim_for_local_rank(
|
||||
self.num_heads, self.world_size, self.local_rank
|
||||
)
|
||||
self.head_dim = self.embed_dim // num_heads
|
||||
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
|
||||
|
||||
|
||||
+96
-32
@@ -17,10 +17,19 @@ except ImportError:
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
|
||||
bias1=True, bias2=True, return_residual=False, device=None, dtype=None):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
activation=F.gelu,
|
||||
bias1=True,
|
||||
bias2=True,
|
||||
return_residual=False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
out_features = out_features if out_features is not None else in_features
|
||||
hidden_features = hidden_features if hidden_features is not None else in_features * 4
|
||||
@@ -37,21 +46,42 @@ class Mlp(nn.Module):
|
||||
|
||||
|
||||
class ParallelMLP(nn.Module):
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
|
||||
process_group: ProcessGroup = None, sequence_parallel=True,
|
||||
bias1=True, bias2=True, device=None, dtype=None):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
activation=F.gelu,
|
||||
process_group: ProcessGroup = None,
|
||||
sequence_parallel=True,
|
||||
bias1=True,
|
||||
bias2=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
assert ColumnParallelLinear is not None, "Need to install fused_dense"
|
||||
assert RowParallelLinear is not None, "Need to install fused_dense"
|
||||
out_features = out_features if out_features is not None else in_features
|
||||
hidden_features = hidden_features if hidden_features is not None else in_features * 4
|
||||
self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group, bias=bias1,
|
||||
sequence_parallel=sequence_parallel, **factory_kwargs)
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
in_features,
|
||||
hidden_features,
|
||||
process_group,
|
||||
bias=bias1,
|
||||
sequence_parallel=sequence_parallel,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.activation = activation
|
||||
self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, bias=bias2,
|
||||
sequence_parallel=sequence_parallel, **factory_kwargs)
|
||||
self.fc2 = RowParallelLinear(
|
||||
hidden_features,
|
||||
out_features,
|
||||
process_group,
|
||||
bias=bias2,
|
||||
sequence_parallel=sequence_parallel,
|
||||
**factory_kwargs,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.fc1(x)
|
||||
@@ -61,15 +91,25 @@ class ParallelMLP(nn.Module):
|
||||
|
||||
|
||||
class GatedMlp(nn.Module):
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,
|
||||
bias1=True, bias2=True, multiple_of=256, return_residual=False,
|
||||
device=None, dtype=None):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
activation=F.sigmoid,
|
||||
bias1=True,
|
||||
bias2=True,
|
||||
multiple_of=256,
|
||||
return_residual=False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
out_features = out_features if out_features is not None else in_features
|
||||
hidden_features = (hidden_features if hidden_features is not None
|
||||
else int(8 * in_features / 3))
|
||||
hidden_features = (
|
||||
hidden_features if hidden_features is not None else int(8 * in_features / 3)
|
||||
)
|
||||
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
|
||||
self.return_residual = return_residual
|
||||
self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
|
||||
@@ -88,24 +128,48 @@ class GatedMlp(nn.Module):
|
||||
|
||||
|
||||
class ParallelGatedMlp(nn.Module):
|
||||
""" Parallel GatedMlp """
|
||||
"""Parallel GatedMlp"""
|
||||
|
||||
def __init__(self, in_features, process_group, hidden_features=None, out_features=None,
|
||||
activation=F.sigmoid, bias1=True, bias2=True, multiple_of=256,
|
||||
sequence_parallel=True, device=None, dtype=None):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
process_group,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
activation=F.sigmoid,
|
||||
bias1=True,
|
||||
bias2=True,
|
||||
multiple_of=256,
|
||||
sequence_parallel=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
out_features = out_features if out_features is not None else in_features
|
||||
hidden_features = (hidden_features if hidden_features is not None
|
||||
else int(8 * in_features / 3))
|
||||
hidden_features = (
|
||||
hidden_features if hidden_features is not None else int(8 * in_features / 3)
|
||||
)
|
||||
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
|
||||
if ColumnParallelLinear is None or RowParallelLinear is None:
|
||||
raise ImportError('fused_dense is not installed')
|
||||
self.fc1 = ColumnParallelLinear(in_features, 2 * hidden_features, process_group, bias=bias1,
|
||||
sequence_parallel=sequence_parallel, **factory_kwargs)
|
||||
raise ImportError("fused_dense is not installed")
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
in_features,
|
||||
2 * hidden_features,
|
||||
process_group,
|
||||
bias=bias1,
|
||||
sequence_parallel=sequence_parallel,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.activation = activation
|
||||
self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, bias=bias2,
|
||||
sequence_parallel=sequence_parallel, **factory_kwargs)
|
||||
self.fc2 = RowParallelLinear(
|
||||
hidden_features,
|
||||
out_features,
|
||||
process_group,
|
||||
bias=bias2,
|
||||
sequence_parallel=sequence_parallel,
|
||||
**factory_kwargs,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.fc1(x)
|
||||
|
||||
@@ -5,7 +5,6 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# 1/sqrt(2*pi)-> 0.3989423
|
||||
# 1/sqrt(2) -> 0.70710678
|
||||
# sqrt(2/pi) -> 0.79788456
|
||||
@@ -18,17 +17,19 @@ def bias_gelu(y, bias):
|
||||
x = bias + y
|
||||
return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)
|
||||
|
||||
|
||||
# gradient of tanh approximation of gelu
|
||||
# gradient of actual gelu is:
|
||||
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
|
||||
@torch.jit.script
|
||||
def bias_gelu_back(g, y, bias):
|
||||
"""Assume that y has shape (B, D) and bias has shape (D)
|
||||
"""
|
||||
"""Assume that y has shape (B, D) and bias has shape (D)"""
|
||||
x = bias + y
|
||||
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
|
||||
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
|
||||
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
|
||||
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
|
||||
1 + tanh_out
|
||||
)
|
||||
grad_y = ff * g
|
||||
return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
|
||||
|
||||
@@ -56,6 +57,7 @@ bias_gelu_impl = GeLUFunction.apply
|
||||
def gelu_fwd(x):
|
||||
return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)
|
||||
|
||||
|
||||
# gradient of tanh approximation of gelu
|
||||
# gradient of actual gelu is:
|
||||
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
|
||||
@@ -63,7 +65,9 @@ def gelu_fwd(x):
|
||||
def gelu_bwd(g, x):
|
||||
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
|
||||
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
|
||||
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
|
||||
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
|
||||
1 + tanh_out
|
||||
)
|
||||
return (ff * g).to(dtype=x.dtype)
|
||||
|
||||
|
||||
@@ -76,10 +80,11 @@ class FastGeLUFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, = ctx.saved_tensors
|
||||
(input,) = ctx.saved_tensors
|
||||
tmp = gelu_bwd(grad_output, input)
|
||||
return tmp
|
||||
|
||||
|
||||
fast_gelu_impl = FastGeLUFunction.apply
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,10 @@ import fused_dense_lib as fused_dense_cuda
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd
|
||||
from flash_attn.utils.distributed import (
|
||||
all_gather_raw,
|
||||
@@ -18,9 +22,6 @@ from flash_attn.utils.distributed import (
|
||||
reduce_scatter,
|
||||
reduce_scatter_raw,
|
||||
)
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
|
||||
class FusedDenseFunc(torch.autograd.Function):
|
||||
|
||||
+535
-110
@@ -1,40 +1,73 @@
|
||||
# Copyright (c) 2022, Tri Dao.
|
||||
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
|
||||
|
||||
import dropout_layer_norm
|
||||
import torch
|
||||
from torch.nn import init
|
||||
|
||||
import dropout_layer_norm
|
||||
|
||||
|
||||
def maybe_align(x, alignment_in_bytes=16):
|
||||
"""Assume that x already has last dim divisible by alignment_in_bytes
|
||||
"""
|
||||
"""Assume that x already has last dim divisible by alignment_in_bytes"""
|
||||
# TD [2023-07-04] I'm not 100% sure that clone will align the memory
|
||||
# https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
|
||||
return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
|
||||
epsilon, residual_in_fp32=False, is_rms_norm=False):
|
||||
""" Assume that arguments are contiguous and aligned to 16 bytes
|
||||
"""
|
||||
def _dropout_add_layer_norm_forward(
|
||||
x0,
|
||||
residual,
|
||||
gamma,
|
||||
beta,
|
||||
rowscale,
|
||||
colscale,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
residual_in_fp32=False,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
"""Assume that arguments are contiguous and aligned to 16 bytes"""
|
||||
hidden_size = gamma.numel()
|
||||
x0mat = x0.view((-1, hidden_size))
|
||||
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
|
||||
rowscale = rowscale.view(-1) if rowscale is not None else None
|
||||
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon,
|
||||
1.0, 0, None, residual_in_fp32, is_rms_norm
|
||||
x0mat,
|
||||
residualmat,
|
||||
gamma,
|
||||
beta,
|
||||
rowscale,
|
||||
colscale,
|
||||
None,
|
||||
None,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
residual_in_fp32,
|
||||
is_rms_norm,
|
||||
)
|
||||
# dmask is None if dropout_p == 0.0
|
||||
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
|
||||
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
|
||||
dropout_p, has_residual, is_rms_norm=False):
|
||||
""" Assume that arguments are contiguous and aligned to 16 bytes
|
||||
def _dropout_add_layer_norm_backward(
|
||||
dz,
|
||||
dx,
|
||||
x,
|
||||
x0,
|
||||
dmask,
|
||||
mu,
|
||||
rsigma,
|
||||
gamma,
|
||||
rowscale,
|
||||
colscale,
|
||||
dropout_p,
|
||||
has_residual,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
"""Assume that arguments are contiguous and aligned to 16 bytes
|
||||
dx == None means that it was a post-norm architecture
|
||||
(x = drop(x0) + residual was not returned in the fwd).
|
||||
x0 must not be None if we have colscale.
|
||||
@@ -46,10 +79,25 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
|
||||
x0mat = x0.view((-1, hidden_size)) if x0 is not None else None
|
||||
rowscale = rowscale.view(-1) if rowscale is not None else None
|
||||
if colscale is not None:
|
||||
assert x0 is not None, 'x0 is required to compute the gradient of colscale'
|
||||
assert x0 is not None, "x0 is required to compute the gradient of colscale"
|
||||
dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
|
||||
dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None,
|
||||
dropout_p, 1.0, 0, has_residual, is_rms_norm
|
||||
dzmat,
|
||||
dxmat,
|
||||
xmat,
|
||||
x0mat,
|
||||
dmask,
|
||||
mu,
|
||||
rsigma,
|
||||
gamma,
|
||||
rowscale,
|
||||
colscale,
|
||||
None,
|
||||
None,
|
||||
dropout_p,
|
||||
1.0,
|
||||
0,
|
||||
has_residual,
|
||||
is_rms_norm,
|
||||
)
|
||||
# dresidualmat is None if not has_residual
|
||||
if colscale is None:
|
||||
@@ -59,29 +107,68 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
|
||||
return dx0mat, dresidualmat, dgamma, dbeta, dcolscale
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset,
|
||||
out_subset, dropout_p, epsilon, rowscale_const,
|
||||
out_numrows, residual_in_fp32=False, is_rms_norm=False):
|
||||
""" Assume that arguments are contiguous and aligned to 16 bytes
|
||||
"""
|
||||
def _dropout_add_layer_norm_subset_forward(
|
||||
x0,
|
||||
residual,
|
||||
gamma,
|
||||
beta,
|
||||
colscale,
|
||||
x0_subset,
|
||||
out_subset,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
rowscale_const,
|
||||
out_numrows,
|
||||
residual_in_fp32=False,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
"""Assume that arguments are contiguous and aligned to 16 bytes"""
|
||||
hidden_size = gamma.numel()
|
||||
x0mat = x0.view((-1, hidden_size))
|
||||
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
|
||||
x0_subset = x0_subset.view(-1) if x0_subset is not None else None
|
||||
out_subset = out_subset.view(-1) if out_subset is not None else None
|
||||
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon,
|
||||
rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm
|
||||
x0mat,
|
||||
residualmat,
|
||||
gamma,
|
||||
beta,
|
||||
None,
|
||||
colscale,
|
||||
x0_subset,
|
||||
out_subset,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
rowscale_const,
|
||||
out_numrows,
|
||||
None,
|
||||
residual_in_fp32,
|
||||
is_rms_norm,
|
||||
)
|
||||
# dmask is None if dropout_p == 0.0
|
||||
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
|
||||
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
|
||||
x0_subset, out_subset, dropout_p, rowscale_const,
|
||||
x0_numrows, has_residual, is_rms_norm=False):
|
||||
""" Assume that arguments are contiguous and aligned to 16 bytes
|
||||
def _dropout_add_layer_norm_subset_backward(
|
||||
dz,
|
||||
dx,
|
||||
x,
|
||||
x0,
|
||||
dmask,
|
||||
mu,
|
||||
rsigma,
|
||||
gamma,
|
||||
colscale,
|
||||
x0_subset,
|
||||
out_subset,
|
||||
dropout_p,
|
||||
rowscale_const,
|
||||
x0_numrows,
|
||||
has_residual,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
"""Assume that arguments are contiguous and aligned to 16 bytes
|
||||
dx == None means that it was a post-norm architecture
|
||||
(x = drop(x0) + residual was not returned in the fwd).
|
||||
x0 must not be None if we have colscale.
|
||||
@@ -94,10 +181,25 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
|
||||
x0_subset = x0_subset.view(-1) if x0_subset is not None else None
|
||||
out_subset = out_subset.view(-1) if out_subset is not None else None
|
||||
if colscale is not None:
|
||||
assert x0 is not None, 'x0 is required to compute the gradient of colscale'
|
||||
assert x0 is not None, "x0 is required to compute the gradient of colscale"
|
||||
dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
|
||||
dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset,
|
||||
dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm
|
||||
dzmat,
|
||||
dxmat,
|
||||
xmat,
|
||||
x0mat,
|
||||
dmask,
|
||||
mu,
|
||||
rsigma,
|
||||
gamma,
|
||||
None,
|
||||
colscale,
|
||||
x0_subset,
|
||||
out_subset,
|
||||
dropout_p,
|
||||
rowscale_const,
|
||||
x0_numrows,
|
||||
has_residual,
|
||||
is_rms_norm,
|
||||
)
|
||||
# dresidualmat is None if not has_residual
|
||||
if colscale is None:
|
||||
@@ -108,18 +210,44 @@ def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, ga
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_parallel_residual_forward(
|
||||
x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p,
|
||||
epsilon, residual_in_fp32=False, is_rms_norm=False
|
||||
x0,
|
||||
x1,
|
||||
residual,
|
||||
gamma0,
|
||||
beta0,
|
||||
gamma1,
|
||||
beta1,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
residual_in_fp32=False,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
""" Assume that arguments are contiguous and aligned to 16 bytes
|
||||
"""
|
||||
"""Assume that arguments are contiguous and aligned to 16 bytes"""
|
||||
hidden_size = gamma0.numel()
|
||||
x0mat = x0.view((-1, hidden_size))
|
||||
x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
|
||||
residualmat = residual.view((-1, hidden_size)) if residual is not None else None
|
||||
z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
|
||||
x0mat, x1mat, residualmat, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
|
||||
None, residual_in_fp32, is_rms_norm
|
||||
(
|
||||
z0mat,
|
||||
z1mat,
|
||||
xmat,
|
||||
dmask0,
|
||||
dmask1,
|
||||
mu,
|
||||
rsigma,
|
||||
) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
|
||||
x0mat,
|
||||
x1mat,
|
||||
residualmat,
|
||||
gamma0,
|
||||
beta0,
|
||||
gamma1,
|
||||
beta1,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
None,
|
||||
residual_in_fp32,
|
||||
is_rms_norm,
|
||||
)
|
||||
# dmask0 and dmask1 are None if dropout_p == 0.0
|
||||
# xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype
|
||||
@@ -127,10 +255,22 @@ def _dropout_add_layer_norm_parallel_residual_forward(
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_parallel_residual_backward(
|
||||
dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
|
||||
dropout_p, has_x1, has_residual, is_rms_norm=False
|
||||
dz0,
|
||||
dz1,
|
||||
dx,
|
||||
x,
|
||||
dmask0,
|
||||
dmask1,
|
||||
mu,
|
||||
rsigma,
|
||||
gamma0,
|
||||
gamma1,
|
||||
dropout_p,
|
||||
has_x1,
|
||||
has_residual,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
""" Assume that arguments are contiguous and aligned to 16 bytes
|
||||
"""Assume that arguments are contiguous and aligned to 16 bytes
|
||||
dx == None means that it was a post-norm architecture
|
||||
(x = drop(x0) + residual was not returned in the fwd).
|
||||
"""
|
||||
@@ -139,9 +279,30 @@ def _dropout_add_layer_norm_parallel_residual_backward(
|
||||
dz0mat = dz0.view(xmat.shape)
|
||||
dz1mat = dz1.view(xmat.shape) if dz1 is not None else None
|
||||
dxmat = dx.view(xmat.shape) if dx is not None else None
|
||||
dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1, *rest = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
|
||||
dz0mat, dz1mat, dxmat, xmat, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
|
||||
dropout_p, has_x1, has_residual, is_rms_norm
|
||||
(
|
||||
dx0mat,
|
||||
dx1mat,
|
||||
dresidualmat,
|
||||
dgamma0,
|
||||
dbeta0,
|
||||
dgamma1,
|
||||
dbeta1,
|
||||
*rest,
|
||||
) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
|
||||
dz0mat,
|
||||
dz1mat,
|
||||
dxmat,
|
||||
xmat,
|
||||
dmask0,
|
||||
dmask1,
|
||||
mu,
|
||||
rsigma,
|
||||
gamma0,
|
||||
gamma1,
|
||||
dropout_p,
|
||||
has_x1,
|
||||
has_residual,
|
||||
is_rms_norm,
|
||||
)
|
||||
# dresidualmat is None if not has_residual
|
||||
return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
|
||||
@@ -149,8 +310,21 @@ def _dropout_add_layer_norm_parallel_residual_backward(
|
||||
|
||||
class DropoutAddLayerNormFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
|
||||
residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
|
||||
def forward(
|
||||
ctx,
|
||||
x0,
|
||||
residual,
|
||||
gamma,
|
||||
beta,
|
||||
rowscale,
|
||||
colscale,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
residual_in_fp32=False,
|
||||
prenorm=False,
|
||||
is_rms_norm=False,
|
||||
return_dmask=False,
|
||||
):
|
||||
x0 = maybe_align(x0.contiguous(), 16)
|
||||
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
|
||||
gamma = maybe_align(gamma.contiguous(), 16)
|
||||
@@ -158,26 +332,43 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
|
||||
rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
|
||||
colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
|
||||
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
|
||||
x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
|
||||
residual_in_fp32, is_rms_norm
|
||||
x0,
|
||||
residual,
|
||||
gamma,
|
||||
beta,
|
||||
rowscale,
|
||||
colscale,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
residual_in_fp32,
|
||||
is_rms_norm,
|
||||
)
|
||||
# Only need to save x0 if we need to compute gradient wrt colscale
|
||||
x0_saved = x0 if colscale is not None else None
|
||||
ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale)
|
||||
ctx.save_for_backward(
|
||||
xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
|
||||
)
|
||||
ctx.prenorm = prenorm
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.has_residual = residual is not None
|
||||
ctx.is_rms_norm = is_rms_norm
|
||||
ctx.has_beta = beta is not None
|
||||
if not return_dmask:
|
||||
return (zmat.view(x0.shape) if not prenorm
|
||||
else (zmat.view(x0.shape), xmat.view(x0.shape)))
|
||||
return (
|
||||
zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
|
||||
)
|
||||
else:
|
||||
dmask = (dmask.view(x0.shape) if dropout_p > 0.
|
||||
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
|
||||
dmask = (
|
||||
dmask.view(x0.shape)
|
||||
if dropout_p > 0.0
|
||||
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
|
||||
)
|
||||
ctx.mark_non_differentiable(dmask)
|
||||
return ((zmat.view(x0.shape), dmask) if not prenorm
|
||||
else (zmat.view(x0.shape), xmat.view(x0.shape), dmask))
|
||||
return (
|
||||
(zmat.view(x0.shape), dmask)
|
||||
if not prenorm
|
||||
else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dz, *args):
|
||||
@@ -189,35 +380,85 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
|
||||
dropout_p = ctx.dropout_p
|
||||
has_residual = ctx.has_residual
|
||||
dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
|
||||
dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual,
|
||||
ctx.is_rms_norm
|
||||
dz,
|
||||
dx,
|
||||
x,
|
||||
x0,
|
||||
dmask,
|
||||
mu,
|
||||
rsigma,
|
||||
gamma,
|
||||
rowscale,
|
||||
colscale,
|
||||
dropout_p,
|
||||
has_residual,
|
||||
ctx.is_rms_norm,
|
||||
)
|
||||
dx0 = dx0mat.view(x.shape)
|
||||
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
|
||||
dcolscale = rest[0] if colscale is not None else None
|
||||
return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None,
|
||||
None, None, None, None, None)
|
||||
return (
|
||||
dx0,
|
||||
dresidual,
|
||||
dgamma,
|
||||
dbeta if ctx.has_beta else None,
|
||||
None,
|
||||
dcolscale,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
|
||||
rowscale_const, out_numrows, residual_in_fp32=False,
|
||||
prenorm=False, is_rms_norm=False, return_dmask=False):
|
||||
def forward(
|
||||
ctx,
|
||||
x0,
|
||||
residual,
|
||||
gamma,
|
||||
beta,
|
||||
colscale,
|
||||
x0_subset,
|
||||
out_subset,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
rowscale_const,
|
||||
out_numrows,
|
||||
residual_in_fp32=False,
|
||||
prenorm=False,
|
||||
is_rms_norm=False,
|
||||
return_dmask=False,
|
||||
):
|
||||
x0 = maybe_align(x0.contiguous(), 16)
|
||||
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
|
||||
gamma = maybe_align(gamma.contiguous(), 16)
|
||||
beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
|
||||
colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
|
||||
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
|
||||
x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
|
||||
rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
|
||||
x0,
|
||||
residual,
|
||||
gamma,
|
||||
beta,
|
||||
colscale,
|
||||
x0_subset,
|
||||
out_subset,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
rowscale_const,
|
||||
out_numrows,
|
||||
residual_in_fp32,
|
||||
is_rms_norm,
|
||||
)
|
||||
# Only need to save x0 if we need to compute gradient wrt colscale
|
||||
x0_saved = x0 if colscale is not None else None
|
||||
x_shape = (-1, *x0.shape[1:])
|
||||
ctx.save_for_backward(xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale,
|
||||
x0_subset, out_subset)
|
||||
ctx.save_for_backward(
|
||||
xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
|
||||
)
|
||||
ctx.prenorm = prenorm
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.rowscale_const = rowscale_const
|
||||
@@ -227,14 +468,16 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
|
||||
ctx.has_beta = beta is not None
|
||||
z_shape = (-1, *x0.shape[1:])
|
||||
if not return_dmask:
|
||||
return (zmat.view(z_shape) if not prenorm
|
||||
else (zmat.view(z_shape), xmat.view(x0.shape)))
|
||||
return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape))
|
||||
else:
|
||||
z = zmat.view(z_shape)
|
||||
dmask = (dmask.view(x0.shape) if dropout_p > 0.
|
||||
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
|
||||
dmask = (
|
||||
dmask.view(x0.shape)
|
||||
if dropout_p > 0.0
|
||||
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
|
||||
)
|
||||
ctx.mark_non_differentiable(dmask)
|
||||
return ((z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask))
|
||||
return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dz, *args):
|
||||
@@ -246,20 +489,63 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
|
||||
dropout_p = ctx.dropout_p
|
||||
has_residual = ctx.has_residual
|
||||
dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
|
||||
dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p,
|
||||
ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm
|
||||
dz,
|
||||
dx,
|
||||
x,
|
||||
x0,
|
||||
dmask,
|
||||
mu,
|
||||
rsigma,
|
||||
gamma,
|
||||
colscale,
|
||||
x0_subset,
|
||||
out_subset,
|
||||
dropout_p,
|
||||
ctx.rowscale_const,
|
||||
ctx.x0_numrows,
|
||||
has_residual,
|
||||
ctx.is_rms_norm,
|
||||
)
|
||||
dx0 = dx0mat.view(-1, *x.shape[1:])
|
||||
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
|
||||
dcolscale = rest[0] if colscale is not None else None
|
||||
return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None,
|
||||
None, None, None, None, None, None, None, None)
|
||||
return (
|
||||
dx0,
|
||||
dresidual,
|
||||
dgamma,
|
||||
dbeta if ctx.has_beta else None,
|
||||
dcolscale,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
|
||||
residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
|
||||
def forward(
|
||||
ctx,
|
||||
x0,
|
||||
x1,
|
||||
residual,
|
||||
gamma0,
|
||||
beta0,
|
||||
gamma1,
|
||||
beta1,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
residual_in_fp32=False,
|
||||
prenorm=False,
|
||||
is_rms_norm=False,
|
||||
return_dmask=False,
|
||||
):
|
||||
x0 = maybe_align(x0.contiguous(), 16)
|
||||
x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
|
||||
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
|
||||
@@ -267,9 +553,26 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
|
||||
beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
|
||||
gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
|
||||
beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
|
||||
z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = _dropout_add_layer_norm_parallel_residual_forward(
|
||||
x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
|
||||
residual_in_fp32, is_rms_norm
|
||||
(
|
||||
z0mat,
|
||||
z1mat,
|
||||
xmat,
|
||||
dmask0,
|
||||
dmask1,
|
||||
mu,
|
||||
rsigma,
|
||||
) = _dropout_add_layer_norm_parallel_residual_forward(
|
||||
x0,
|
||||
x1,
|
||||
residual,
|
||||
gamma0,
|
||||
beta0,
|
||||
gamma1,
|
||||
beta1,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
residual_in_fp32,
|
||||
is_rms_norm,
|
||||
)
|
||||
ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
|
||||
ctx.prenorm = prenorm
|
||||
@@ -282,13 +585,21 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
|
||||
if not return_dmask:
|
||||
return z if not prenorm else (*z, xmat.view(x0.shape))
|
||||
else:
|
||||
dmask0 = (dmask0.view(x0.shape) if dropout_p > 0.
|
||||
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
|
||||
dmask1 = (dmask1.view(x0.shape) if dropout_p > 0. and x1 is not None
|
||||
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
|
||||
dmask0 = (
|
||||
dmask0.view(x0.shape)
|
||||
if dropout_p > 0.0
|
||||
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
|
||||
)
|
||||
dmask1 = (
|
||||
dmask1.view(x0.shape)
|
||||
if dropout_p > 0.0 and x1 is not None
|
||||
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
|
||||
)
|
||||
ctx.mark_non_differentiable(dmask0)
|
||||
ctx.mark_non_differentiable(dmask1)
|
||||
return (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
|
||||
return (
|
||||
(*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dz0, dz1, *args):
|
||||
@@ -299,63 +610,170 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
|
||||
dropout_p = ctx.dropout_p
|
||||
has_x1 = ctx.has_x1
|
||||
has_residual = ctx.has_residual
|
||||
dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 = _dropout_add_layer_norm_parallel_residual_backward(
|
||||
dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p, has_x1,
|
||||
has_residual, ctx.is_rms_norm
|
||||
(
|
||||
dx0mat,
|
||||
dx1mat,
|
||||
dresidualmat,
|
||||
dgamma0,
|
||||
dbeta0,
|
||||
dgamma1,
|
||||
dbeta1,
|
||||
) = _dropout_add_layer_norm_parallel_residual_backward(
|
||||
dz0,
|
||||
dz1,
|
||||
dx,
|
||||
x,
|
||||
dmask0,
|
||||
dmask1,
|
||||
mu,
|
||||
rsigma,
|
||||
gamma0,
|
||||
gamma1,
|
||||
dropout_p,
|
||||
has_x1,
|
||||
has_residual,
|
||||
ctx.is_rms_norm,
|
||||
)
|
||||
dx0 = dx0mat.view(x.shape)
|
||||
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
|
||||
dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
|
||||
return (dx0, dx1, dresidual, dgamma0, dbeta0 if ctx.has_beta else None, dgamma1,
|
||||
dbeta1 if ctx.has_beta else None, None, None, None, None, None, None)
|
||||
return (
|
||||
dx0,
|
||||
dx1,
|
||||
dresidual,
|
||||
dgamma0,
|
||||
dbeta0 if ctx.has_beta else None,
|
||||
dgamma1,
|
||||
dbeta1 if ctx.has_beta else None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def layer_norm(x, weight, bias, epsilon):
|
||||
return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
|
||||
|
||||
|
||||
def dropout_add_layer_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
|
||||
layerscale=None, prenorm=False, residual_in_fp32=False,
|
||||
return_dropout_mask=False):
|
||||
def dropout_add_layer_norm(
|
||||
x0,
|
||||
residual,
|
||||
weight,
|
||||
bias,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
rowscale=None,
|
||||
layerscale=None,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
return_dropout_mask=False,
|
||||
):
|
||||
"""residual_in_fp32 only has an effect if residual is None.
|
||||
Otherwise residual dtype is residual.dtype.
|
||||
"""
|
||||
return DropoutAddLayerNormFn.apply(
|
||||
x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
|
||||
False, return_dropout_mask
|
||||
x0,
|
||||
residual,
|
||||
weight,
|
||||
bias,
|
||||
rowscale,
|
||||
layerscale,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
residual_in_fp32,
|
||||
prenorm,
|
||||
False,
|
||||
return_dropout_mask,
|
||||
)
|
||||
|
||||
|
||||
def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
|
||||
x0_subset=None, out_subset=None, rowscale_const=1.0,
|
||||
out_numrows=0, prenorm=False, residual_in_fp32=False,
|
||||
return_dropout_mask=False):
|
||||
def dropout_add_layer_norm_subset(
|
||||
x0,
|
||||
residual,
|
||||
weight,
|
||||
bias,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
layerscale=None,
|
||||
x0_subset=None,
|
||||
out_subset=None,
|
||||
rowscale_const=1.0,
|
||||
out_numrows=0,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
return_dropout_mask=False,
|
||||
):
|
||||
"""residual_in_fp32 only has an effect if residual is None.
|
||||
Otherwise residual dtype is residual.dtype.
|
||||
"""
|
||||
return DropoutAddLayerNormSubsetFn.apply(
|
||||
x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
|
||||
rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask
|
||||
x0,
|
||||
residual,
|
||||
weight,
|
||||
bias,
|
||||
layerscale,
|
||||
x0_subset,
|
||||
out_subset,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
rowscale_const,
|
||||
out_numrows,
|
||||
residual_in_fp32,
|
||||
prenorm,
|
||||
False,
|
||||
return_dropout_mask,
|
||||
)
|
||||
|
||||
|
||||
def dropout_add_layer_norm_parallel_residual(
|
||||
x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, prenorm=False,
|
||||
residual_in_fp32=False, return_dropout_mask=False
|
||||
x0,
|
||||
x1,
|
||||
residual,
|
||||
weight0,
|
||||
bias0,
|
||||
weight1,
|
||||
bias1,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
return_dropout_mask=False,
|
||||
):
|
||||
"""residual_in_fp32 only has an effect if residual is None.
|
||||
Otherwise residual dtype is residual.dtype.
|
||||
"""
|
||||
return DropoutAddLayerNormParallelResidualFn.apply(
|
||||
x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm,
|
||||
False, return_dropout_mask
|
||||
x0,
|
||||
x1,
|
||||
residual,
|
||||
weight0,
|
||||
bias0,
|
||||
weight1,
|
||||
bias1,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
residual_in_fp32,
|
||||
prenorm,
|
||||
False,
|
||||
return_dropout_mask,
|
||||
)
|
||||
|
||||
|
||||
class DropoutAddLayerNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
|
||||
device=None, dtype=None):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
prenorm=False,
|
||||
p=0.0,
|
||||
eps=1e-5,
|
||||
residual_in_fp32=False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.prenorm = prenorm
|
||||
self.p = p
|
||||
@@ -370,6 +788,13 @@ class DropoutAddLayerNorm(torch.nn.Module):
|
||||
init.zeros_(self.bias)
|
||||
|
||||
def forward(self, x0, residual=None):
|
||||
return dropout_add_layer_norm(x0, residual, self.weight, self.bias,
|
||||
self.p if self.training else 0.0, self.eps,
|
||||
prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
|
||||
return dropout_add_layer_norm(
|
||||
x0,
|
||||
residual,
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.p if self.training else 0.0,
|
||||
self.eps,
|
||||
prenorm=self.prenorm,
|
||||
residual_in_fp32=self.residual_in_fp32,
|
||||
)
|
||||
|
||||
+113
-28
@@ -4,60 +4,130 @@
|
||||
import torch
|
||||
from torch.nn import init
|
||||
|
||||
from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn
|
||||
from flash_attn.ops.layer_norm import DropoutAddLayerNormParallelResidualFn
|
||||
from flash_attn.ops.layer_norm import (
|
||||
DropoutAddLayerNormFn,
|
||||
DropoutAddLayerNormParallelResidualFn,
|
||||
DropoutAddLayerNormSubsetFn,
|
||||
)
|
||||
|
||||
|
||||
def rms_norm(x, weight, epsilon):
|
||||
return DropoutAddLayerNormFn.apply(x, None, weight, None, None, None, 0.0, epsilon, False,
|
||||
False, True)
|
||||
return DropoutAddLayerNormFn.apply(
|
||||
x, None, weight, None, None, None, 0.0, epsilon, False, False, True
|
||||
)
|
||||
|
||||
|
||||
def dropout_add_rms_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
|
||||
layerscale=None, prenorm=False, residual_in_fp32=False,
|
||||
return_dropout_mask=False):
|
||||
def dropout_add_rms_norm(
|
||||
x0,
|
||||
residual,
|
||||
weight,
|
||||
bias,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
rowscale=None,
|
||||
layerscale=None,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
return_dropout_mask=False,
|
||||
):
|
||||
"""residual_in_fp32 only has an effect if residual is None.
|
||||
Otherwise residual dtype is residual.dtype.
|
||||
"""
|
||||
return DropoutAddLayerNormFn.apply(
|
||||
x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
|
||||
True, return_dropout_mask
|
||||
x0,
|
||||
residual,
|
||||
weight,
|
||||
bias,
|
||||
rowscale,
|
||||
layerscale,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
residual_in_fp32,
|
||||
prenorm,
|
||||
True,
|
||||
return_dropout_mask,
|
||||
)
|
||||
|
||||
|
||||
def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
|
||||
x0_subset=None, out_subset=None, rowscale_const=1.0,
|
||||
out_numrows=0, prenorm=False, residual_in_fp32=False,
|
||||
return_dropout_mask=False):
|
||||
def dropout_add_rms_norm_subset(
|
||||
x0,
|
||||
residual,
|
||||
weight,
|
||||
bias,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
layerscale=None,
|
||||
x0_subset=None,
|
||||
out_subset=None,
|
||||
rowscale_const=1.0,
|
||||
out_numrows=0,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
return_dropout_mask=False,
|
||||
):
|
||||
"""residual_in_fp32 only has an effect if residual is None.
|
||||
Otherwise residual dtype is residual.dtype.
|
||||
"""
|
||||
return DropoutAddLayerNormSubsetFn.apply(
|
||||
x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
|
||||
rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask
|
||||
x0,
|
||||
residual,
|
||||
weight,
|
||||
bias,
|
||||
layerscale,
|
||||
x0_subset,
|
||||
out_subset,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
rowscale_const,
|
||||
out_numrows,
|
||||
residual_in_fp32,
|
||||
prenorm,
|
||||
True,
|
||||
return_dropout_mask,
|
||||
)
|
||||
|
||||
|
||||
def dropout_add_rms_norm_parallel_residual(
|
||||
x0, x1, residual, weight0, bias0, weight1, bias1,
|
||||
dropout_p, epsilon, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
|
||||
x0,
|
||||
x1,
|
||||
residual,
|
||||
weight0,
|
||||
bias0,
|
||||
weight1,
|
||||
bias1,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
prenorm=False,
|
||||
residual_in_fp32=False,
|
||||
return_dropout_mask=False,
|
||||
):
|
||||
"""residual_in_fp32 only has an effect if residual is None.
|
||||
Otherwise residual dtype is residual.dtype.
|
||||
"""
|
||||
return DropoutAddLayerNormParallelResidualFn.apply(
|
||||
x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm,
|
||||
True, return_dropout_mask
|
||||
x0,
|
||||
x1,
|
||||
residual,
|
||||
weight0,
|
||||
bias0,
|
||||
weight1,
|
||||
bias1,
|
||||
dropout_p,
|
||||
epsilon,
|
||||
residual_in_fp32,
|
||||
prenorm,
|
||||
True,
|
||||
return_dropout_mask,
|
||||
)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.register_parameter('bias', None)
|
||||
self.register_parameter("bias", None)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
@@ -68,22 +138,37 @@ class RMSNorm(torch.nn.Module):
|
||||
|
||||
|
||||
class DropoutAddRMSNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
|
||||
device=None, dtype=None):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
prenorm=False,
|
||||
p=0.0,
|
||||
eps=1e-5,
|
||||
residual_in_fp32=False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.prenorm = prenorm
|
||||
self.p = p
|
||||
self.eps = eps
|
||||
self.residual_in_fp32 = residual_in_fp32
|
||||
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.register_parameter('bias', None)
|
||||
self.register_parameter("bias", None)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
init.ones_(self.weight)
|
||||
|
||||
def forward(self, x0, residual=None):
|
||||
return dropout_add_rms_norm(x0, residual, self.weight, None,
|
||||
self.p if self.training else 0.0, self.eps,
|
||||
prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
|
||||
return dropout_add_rms_norm(
|
||||
x0,
|
||||
residual,
|
||||
self.weight,
|
||||
None,
|
||||
self.p if self.training else 0.0,
|
||||
self.eps,
|
||||
prenorm=self.prenorm,
|
||||
residual_in_fp32=self.residual_in_fp32,
|
||||
)
|
||||
|
||||
@@ -11,7 +11,6 @@ from typing import Optional
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
_sqrt2pi = math.sqrt(2.0 / math.pi)
|
||||
_sqrt1_2 = math.sqrt(1.0 / 2)
|
||||
_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)
|
||||
@@ -142,6 +141,7 @@ def gelu_grad(x):
|
||||
pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
|
||||
return cdf + x * pdf
|
||||
|
||||
|
||||
@triton.jit
|
||||
def gelu_approx(x):
|
||||
"""
|
||||
@@ -157,6 +157,6 @@ def gelu_approx_grad(x):
|
||||
# CREDITS: Fast implementation proposed in
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
|
||||
tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
|
||||
return 0.5 * x * (
|
||||
(1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
|
||||
) + 0.5 * (1 + tanh_out)
|
||||
return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
|
||||
1 + tanh_out
|
||||
)
|
||||
|
||||
+173
-56
@@ -9,8 +9,14 @@ from torch.autograd.function import FunctionCtx
|
||||
from torch.cuda.amp import custom_fwd
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
from flash_attn.ops.triton.k_activations import gelu, gelu_grad, gelu_approx, gelu_approx_grad, squared_relu, squared_relu_grad
|
||||
|
||||
from flash_attn.ops.triton.k_activations import (
|
||||
gelu,
|
||||
gelu_approx,
|
||||
gelu_approx_grad,
|
||||
gelu_grad,
|
||||
squared_relu,
|
||||
squared_relu_grad,
|
||||
)
|
||||
|
||||
# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications
|
||||
|
||||
@@ -28,7 +34,12 @@ def get_configs_io_bound():
|
||||
num_warps = 2 if block_n <= 64 else 4
|
||||
configs.append(
|
||||
triton.Config(
|
||||
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
|
||||
{
|
||||
"BLOCK_M": block_m,
|
||||
"BLOCK_N": block_n,
|
||||
"BLOCK_K": block_k,
|
||||
"SPLIT_K": 1,
|
||||
},
|
||||
num_stages=num_stages,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
@@ -43,29 +54,75 @@ def get_configs_io_bound():
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2
|
||||
),
|
||||
# good for int8
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2
|
||||
),
|
||||
]
|
||||
+ get_configs_io_bound(),
|
||||
key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
|
||||
prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
|
||||
prune_configs_by={
|
||||
"early_config_prune": early_config_prune,
|
||||
"perf_model": estimate_matmul_time,
|
||||
"top_k": 10,
|
||||
},
|
||||
)
|
||||
@triton.heuristics(
|
||||
{
|
||||
@@ -204,7 +261,7 @@ def triton_linear_act(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
activation: str = 'id',
|
||||
activation: str = "id",
|
||||
save_act_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -221,7 +278,7 @@ def triton_linear_act(
|
||||
# dtype = torch.get_autocast_gpu_dtype()
|
||||
# x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
|
||||
|
||||
assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu']
|
||||
assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]
|
||||
|
||||
batch_shape, n = x.shape[:-1], x.shape[-1]
|
||||
batch_dim = batch_shape.numel()
|
||||
@@ -233,12 +290,20 @@ def triton_linear_act(
|
||||
weight = weight.contiguous()
|
||||
bias = bias.contiguous() if bias is not None else None
|
||||
|
||||
assert x.dtype == weight.dtype, f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
|
||||
assert (
|
||||
x.dtype == weight.dtype
|
||||
), f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
|
||||
if bias is not None:
|
||||
assert x.dtype == bias.dtype, f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
|
||||
assert x_reshaped.shape[1] == weight.shape[1], f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}"
|
||||
assert (
|
||||
x.dtype == bias.dtype
|
||||
), f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
|
||||
assert (
|
||||
x_reshaped.shape[1] == weight.shape[1]
|
||||
), f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}"
|
||||
|
||||
assert bias is None or bias.shape[0] == weight.shape[0], "Incompatible dimensions in between weight and bias"
|
||||
assert (
|
||||
bias is None or bias.shape[0] == weight.shape[0]
|
||||
), "Incompatible dimensions in between weight and bias"
|
||||
|
||||
M, K = x_reshaped.shape
|
||||
N, K = weight.shape
|
||||
@@ -278,35 +343,83 @@ def triton_linear_act(
|
||||
if not save_act_input:
|
||||
return output.reshape(*batch_shape, output.shape[-1])
|
||||
else:
|
||||
return (output.reshape(*batch_shape, output.shape[-1]),
|
||||
act_input.reshape(*batch_shape, act_input.shape[-1]))
|
||||
return (
|
||||
output.reshape(*batch_shape, output.shape[-1]),
|
||||
act_input.reshape(*batch_shape, act_input.shape[-1]),
|
||||
)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2
|
||||
),
|
||||
# good for int8
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2
|
||||
),
|
||||
]
|
||||
+ get_configs_io_bound(),
|
||||
key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
|
||||
prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
|
||||
prune_configs_by={
|
||||
"early_config_prune": early_config_prune,
|
||||
"perf_model": estimate_matmul_time,
|
||||
"top_k": 10,
|
||||
},
|
||||
)
|
||||
@triton.heuristics(
|
||||
{
|
||||
@@ -395,7 +508,7 @@ def kernel_bwd(
|
||||
B += BLOCK_K * stride_bk
|
||||
|
||||
# optional: fused activation (while the data is in shared memory)
|
||||
if ACTIVATION != 'id':
|
||||
if ACTIVATION != "id":
|
||||
act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
|
||||
act_input = tl.load(act_in_ptrs).to(acc.dtype)
|
||||
if ACTIVATION == "gelu":
|
||||
@@ -418,7 +531,7 @@ def kernel_bwd(
|
||||
def triton_dgrad_act(
|
||||
grad_output: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
activation: str = 'id',
|
||||
activation: str = "id",
|
||||
act_input: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -430,7 +543,7 @@ def triton_dgrad_act(
|
||||
:param act_input: an optional tensor to save the activation inputs (for backward)
|
||||
:return: result tensor
|
||||
"""
|
||||
assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu']
|
||||
assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]
|
||||
|
||||
batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]
|
||||
batch_dim = batch_shape.numel()
|
||||
@@ -441,10 +554,14 @@ def triton_dgrad_act(
|
||||
if weight.stride(0) > 1 and weight.stride(1) > 1:
|
||||
weight = weight.contiguous()
|
||||
|
||||
assert grad_output.dtype == weight.dtype, f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}"
|
||||
assert grad_output_reshaped.shape[1] == weight.shape[0], f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
|
||||
if activation != 'id':
|
||||
assert act_input is not None, f'act_input is required for activation {activation}'
|
||||
assert (
|
||||
grad_output.dtype == weight.dtype
|
||||
), f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}"
|
||||
assert (
|
||||
grad_output_reshaped.shape[1] == weight.shape[0]
|
||||
), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
|
||||
if activation != "id":
|
||||
assert act_input is not None, f"act_input is required for activation {activation}"
|
||||
|
||||
# M, N, K in bwd are different from M, N, K in fwd
|
||||
M, K = grad_output_reshaped.shape
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
# The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared
|
||||
# to naive implementation.
|
||||
import fused_dense_lib as fused_dense_cuda
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
import fused_dense_lib as fused_dense_cuda
|
||||
|
||||
from flash_attn.ops.triton.linear import triton_linear_act, triton_dgrad_act
|
||||
from flash_attn.ops.activations import sqrelu_fwd, sqrelu_bwd
|
||||
from flash_attn.ops.activations import sqrelu_bwd, sqrelu_fwd
|
||||
from flash_attn.ops.triton.linear import triton_dgrad_act, triton_linear_act
|
||||
|
||||
|
||||
class FusedDenseSqreluDenseFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0):
|
||||
@@ -23,8 +21,9 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
|
||||
"""
|
||||
if torch.is_autocast_enabled():
|
||||
dtype = torch.get_autocast_gpu_dtype()
|
||||
x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
|
||||
for a in [x, weight1, bias1, weight2, bias2]]
|
||||
x, weight1, bias1, weight2, bias2 = [
|
||||
a.to(dtype=dtype) for a in [x, weight1, bias1, weight2, bias2]
|
||||
]
|
||||
is_bf16 = x.dtype == torch.bfloat16
|
||||
assert checkpoint_lvl in [0, 1, 2]
|
||||
x = x.contiguous()
|
||||
@@ -35,13 +34,18 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
|
||||
batch_shape, n = x.shape[:-1], x.shape[-1]
|
||||
batch_dim = batch_shape.numel()
|
||||
if is_bf16:
|
||||
act_input = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
|
||||
act_input = fused_dense_cuda.linear_bias_forward(
|
||||
x.reshape(batch_dim, n), weight1, bias1
|
||||
)
|
||||
output1 = sqrelu_fwd(act_input)
|
||||
else:
|
||||
save_act_input = checkpoint_lvl != 2
|
||||
result = triton_linear_act(
|
||||
x.reshape(batch_dim, n), weight1, bias1, activation='squared_relu',
|
||||
save_act_input=save_act_input
|
||||
x.reshape(batch_dim, n),
|
||||
weight1,
|
||||
bias1,
|
||||
activation="squared_relu",
|
||||
save_act_input=save_act_input,
|
||||
)
|
||||
if save_act_input:
|
||||
output1, act_input = result
|
||||
@@ -69,16 +73,21 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
|
||||
if checkpoint_lvl == 0:
|
||||
act_input, output1 = rest
|
||||
elif checkpoint_lvl == 1:
|
||||
act_input, = rest
|
||||
(act_input,) = rest
|
||||
output1 = sqrelu_fwd(act_input)
|
||||
elif checkpoint_lvl == 2:
|
||||
if is_bf16:
|
||||
act_input = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
|
||||
act_input = fused_dense_cuda.linear_bias_forward(
|
||||
x.reshape(batch_dim, n), weight1, bias1
|
||||
)
|
||||
output1 = sqrelu_fwd(act_input)
|
||||
else:
|
||||
output1, act_input = triton_linear_act(
|
||||
x.reshape(batch_dim, n), weight1, bias1, activation='squared_relu',
|
||||
save_act_input=True
|
||||
x.reshape(batch_dim, n),
|
||||
weight1,
|
||||
bias1,
|
||||
activation="squared_relu",
|
||||
save_act_input=True,
|
||||
)
|
||||
|
||||
if is_bf16:
|
||||
@@ -92,8 +101,9 @@ class FusedDenseSqreluDenseFunc(torch.autograd.Function):
|
||||
else:
|
||||
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
|
||||
grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
|
||||
grad_act_input = triton_dgrad_act(grad_output, weight2, activation='squared_relu',
|
||||
act_input=act_input)
|
||||
grad_act_input = triton_dgrad_act(
|
||||
grad_output, weight2, activation="squared_relu", act_input=act_input
|
||||
)
|
||||
grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
|
||||
x.reshape(batch_dim, n), weight1, grad_act_input
|
||||
)
|
||||
@@ -104,9 +114,17 @@ fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply
|
||||
|
||||
|
||||
class FusedDenseSqreluDense(nn.Module):
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, bias1=True, bias2=True,
|
||||
checkpoint_lvl=0, device=None, dtype=None):
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
bias1=True,
|
||||
bias2=True,
|
||||
checkpoint_lvl=0,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
checkpoint_lvl (increasing lvl means slower but more memory saving):
|
||||
0: no recomputation in the bwd
|
||||
@@ -114,7 +132,7 @@ class FusedDenseSqreluDense(nn.Module):
|
||||
2: recompute gelu_in and gelu_out in the bwd
|
||||
"""
|
||||
assert checkpoint_lvl in [0, 1, 2]
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features * 4
|
||||
@@ -126,6 +144,6 @@ class FusedDenseSqreluDense(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
assert x.is_cuda
|
||||
return fused_dense_sqrelu_dense_function(x, self.fc1.weight, self.fc1.bias,
|
||||
self.fc2.weight, self.fc2.bias,
|
||||
self.checkpoint_lvl)
|
||||
return fused_dense_sqrelu_dense_function(
|
||||
x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias, self.checkpoint_lvl
|
||||
)
|
||||
|
||||
+155
-58
@@ -5,31 +5,43 @@ import torch
|
||||
import torch.utils.benchmark as benchmark
|
||||
|
||||
|
||||
def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, amp=False,
|
||||
amp_dtype=torch.float16, **kwinputs):
|
||||
""" Use Pytorch Benchmark on the forward pass of an arbitrary function. """
|
||||
def benchmark_forward(
|
||||
fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
|
||||
):
|
||||
"""Use Pytorch Benchmark on the forward pass of an arbitrary function."""
|
||||
if verbose:
|
||||
print(desc, '- Forward pass')
|
||||
print(desc, "- Forward pass")
|
||||
|
||||
def amp_wrapper(*inputs, **kwinputs):
|
||||
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
|
||||
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
|
||||
fn(*inputs, **kwinputs)
|
||||
|
||||
t = benchmark.Timer(
|
||||
stmt='fn_amp(*inputs, **kwinputs)',
|
||||
globals={'fn_amp': amp_wrapper, 'inputs': inputs, 'kwinputs': kwinputs},
|
||||
num_threads=torch.get_num_threads(),
|
||||
)
|
||||
stmt="fn_amp(*inputs, **kwinputs)",
|
||||
globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
|
||||
num_threads=torch.get_num_threads(),
|
||||
)
|
||||
m = t.timeit(repeats)
|
||||
if verbose:
|
||||
print(m)
|
||||
return t, m
|
||||
|
||||
|
||||
def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
|
||||
amp_dtype=torch.float16, **kwinputs):
|
||||
""" Use Pytorch Benchmark on the backward pass of an arbitrary function. """
|
||||
def benchmark_backward(
|
||||
fn,
|
||||
*inputs,
|
||||
grad=None,
|
||||
repeats=10,
|
||||
desc="",
|
||||
verbose=True,
|
||||
amp=False,
|
||||
amp_dtype=torch.float16,
|
||||
**kwinputs,
|
||||
):
|
||||
"""Use Pytorch Benchmark on the backward pass of an arbitrary function."""
|
||||
if verbose:
|
||||
print(desc, '- Backward pass')
|
||||
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
|
||||
print(desc, "- Backward pass")
|
||||
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
|
||||
y = fn(*inputs, **kwinputs)
|
||||
if type(y) is tuple:
|
||||
y = y[0]
|
||||
@@ -37,7 +49,8 @@ def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
|
||||
grad = torch.randn_like(y)
|
||||
else:
|
||||
if grad.shape != y.shape:
|
||||
raise RuntimeError('Grad shape does not match output shape')
|
||||
raise RuntimeError("Grad shape does not match output shape")
|
||||
|
||||
def f(*inputs, y, grad):
|
||||
# Set .grad to None to avoid extra operation of gradient accumulation
|
||||
for x in inputs:
|
||||
@@ -46,22 +59,31 @@ def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
|
||||
y.backward(grad, retain_graph=True)
|
||||
|
||||
t = benchmark.Timer(
|
||||
stmt='f(*inputs, y=y, grad=grad)',
|
||||
globals={'f': f, 'inputs': inputs, 'y': y, 'grad': grad},
|
||||
num_threads=torch.get_num_threads(),
|
||||
)
|
||||
stmt="f(*inputs, y=y, grad=grad)",
|
||||
globals={"f": f, "inputs": inputs, "y": y, "grad": grad},
|
||||
num_threads=torch.get_num_threads(),
|
||||
)
|
||||
m = t.timeit(repeats)
|
||||
if verbose:
|
||||
print(m)
|
||||
return t, m
|
||||
|
||||
|
||||
def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
|
||||
amp_dtype=torch.float16, **kwinputs):
|
||||
""" Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """
|
||||
def benchmark_combined(
|
||||
fn,
|
||||
*inputs,
|
||||
grad=None,
|
||||
repeats=10,
|
||||
desc="",
|
||||
verbose=True,
|
||||
amp=False,
|
||||
amp_dtype=torch.float16,
|
||||
**kwinputs,
|
||||
):
|
||||
"""Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
|
||||
if verbose:
|
||||
print(desc, '- Forward + Backward pass')
|
||||
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
|
||||
print(desc, "- Forward + Backward pass")
|
||||
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
|
||||
y = fn(*inputs, **kwinputs)
|
||||
if type(y) is tuple:
|
||||
y = y[0]
|
||||
@@ -69,68 +91,142 @@ def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
|
||||
grad = torch.randn_like(y)
|
||||
else:
|
||||
if grad.shape != y.shape:
|
||||
raise RuntimeError('Grad shape does not match output shape')
|
||||
raise RuntimeError("Grad shape does not match output shape")
|
||||
|
||||
def f(grad, *inputs, **kwinputs):
|
||||
for x in inputs:
|
||||
if isinstance(x, torch.Tensor):
|
||||
x.grad = None
|
||||
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
|
||||
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
|
||||
y = fn(*inputs, **kwinputs)
|
||||
if type(y) is tuple:
|
||||
y = y[0]
|
||||
y.backward(grad, retain_graph=True)
|
||||
|
||||
t = benchmark.Timer(
|
||||
stmt='f(grad, *inputs, **kwinputs)',
|
||||
globals={'f': f, 'fn': fn, 'inputs': inputs, 'grad': grad, 'kwinputs': kwinputs},
|
||||
num_threads=torch.get_num_threads(),
|
||||
)
|
||||
stmt="f(grad, *inputs, **kwinputs)",
|
||||
globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs},
|
||||
num_threads=torch.get_num_threads(),
|
||||
)
|
||||
m = t.timeit(repeats)
|
||||
if verbose:
|
||||
print(m)
|
||||
return t, m
|
||||
|
||||
|
||||
def benchmark_fwd_bwd(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
|
||||
amp_dtype=torch.float16, **kwinputs):
|
||||
""" Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """
|
||||
def benchmark_fwd_bwd(
|
||||
fn,
|
||||
*inputs,
|
||||
grad=None,
|
||||
repeats=10,
|
||||
desc="",
|
||||
verbose=True,
|
||||
amp=False,
|
||||
amp_dtype=torch.float16,
|
||||
**kwinputs,
|
||||
):
|
||||
"""Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
|
||||
return (
|
||||
benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose,
|
||||
amp=amp, amp_dtype=amp_dtype, **kwinputs),
|
||||
benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
|
||||
amp=amp, amp_dtype=amp_dtype, **kwinputs),
|
||||
benchmark_forward(
|
||||
fn,
|
||||
*inputs,
|
||||
repeats=repeats,
|
||||
desc=desc,
|
||||
verbose=verbose,
|
||||
amp=amp,
|
||||
amp_dtype=amp_dtype,
|
||||
**kwinputs,
|
||||
),
|
||||
benchmark_backward(
|
||||
fn,
|
||||
*inputs,
|
||||
grad=grad,
|
||||
repeats=repeats,
|
||||
desc=desc,
|
||||
verbose=verbose,
|
||||
amp=amp,
|
||||
amp_dtype=amp_dtype,
|
||||
**kwinputs,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
|
||||
amp_dtype=torch.float16, **kwinputs):
|
||||
""" Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """
|
||||
def benchmark_all(
|
||||
fn,
|
||||
*inputs,
|
||||
grad=None,
|
||||
repeats=10,
|
||||
desc="",
|
||||
verbose=True,
|
||||
amp=False,
|
||||
amp_dtype=torch.float16,
|
||||
**kwinputs,
|
||||
):
|
||||
"""Use Pytorch Benchmark on the forward+backward pass of an arbitrary function."""
|
||||
return (
|
||||
benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose,
|
||||
amp=amp, amp_dtype=amp_dtype, **kwinputs),
|
||||
benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
|
||||
amp=amp, amp_dtype=amp_dtype, **kwinputs),
|
||||
benchmark_combined(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
|
||||
amp=amp, amp_dtype=amp_dtype, **kwinputs),
|
||||
benchmark_forward(
|
||||
fn,
|
||||
*inputs,
|
||||
repeats=repeats,
|
||||
desc=desc,
|
||||
verbose=verbose,
|
||||
amp=amp,
|
||||
amp_dtype=amp_dtype,
|
||||
**kwinputs,
|
||||
),
|
||||
benchmark_backward(
|
||||
fn,
|
||||
*inputs,
|
||||
grad=grad,
|
||||
repeats=repeats,
|
||||
desc=desc,
|
||||
verbose=verbose,
|
||||
amp=amp,
|
||||
amp_dtype=amp_dtype,
|
||||
**kwinputs,
|
||||
),
|
||||
benchmark_combined(
|
||||
fn,
|
||||
*inputs,
|
||||
grad=grad,
|
||||
repeats=repeats,
|
||||
desc=desc,
|
||||
verbose=verbose,
|
||||
amp=amp,
|
||||
amp_dtype=amp_dtype,
|
||||
**kwinputs,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False,
|
||||
amp_dtype=torch.float16, cpu=False, verbose=True, **kwinputs):
|
||||
""" Wrap benchmark functions in Pytorch profiler to see CUDA information. """
|
||||
def pytorch_profiler(
|
||||
fn,
|
||||
*inputs,
|
||||
trace_filename=None,
|
||||
backward=False,
|
||||
amp=False,
|
||||
amp_dtype=torch.float16,
|
||||
cpu=False,
|
||||
verbose=True,
|
||||
**kwinputs,
|
||||
):
|
||||
"""Wrap benchmark functions in Pytorch profiler to see CUDA information."""
|
||||
if backward:
|
||||
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
|
||||
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
|
||||
g = torch.randn_like(fn(*inputs, **kwinputs))
|
||||
for _ in range(30): # Warm up
|
||||
for _ in range(30): # Warm up
|
||||
if backward:
|
||||
for x in inputs:
|
||||
if isinstance(x, torch.Tensor):
|
||||
x.grad = None
|
||||
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
|
||||
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
|
||||
out = fn(*inputs, **kwinputs)
|
||||
# Backward should be done outside autocast
|
||||
if backward:
|
||||
out.backward(g, retain_graph=True)
|
||||
activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [torch.profiler.ProfilerActivity.CUDA]
|
||||
activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
|
||||
torch.profiler.ProfilerActivity.CUDA
|
||||
]
|
||||
with torch.profiler.profile(
|
||||
activities=activities,
|
||||
record_shapes=True,
|
||||
@@ -141,9 +237,10 @@ def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False
|
||||
for x in inputs:
|
||||
if isinstance(x, torch.Tensor):
|
||||
x.grad = None
|
||||
with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
|
||||
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
|
||||
out = fn(*inputs, **kwinputs)
|
||||
if backward: out.backward(g, retain_graph=True)
|
||||
if backward:
|
||||
out.backward(g, retain_graph=True)
|
||||
if verbose:
|
||||
# print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
|
||||
print(prof.key_averages().table(row_limit=50))
|
||||
@@ -151,14 +248,14 @@ def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False
|
||||
prof.export_chrome_trace(trace_filename)
|
||||
|
||||
|
||||
def benchmark_memory(fn, *inputs, desc='', verbose=True, **kwinputs):
|
||||
def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
fn(*inputs, **kwinputs)
|
||||
torch.cuda.synchronize()
|
||||
mem = torch.cuda.max_memory_allocated() / ((2 ** 20) * 1000)
|
||||
mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000)
|
||||
if verbose:
|
||||
print(f'{desc} max memory: {mem}GB')
|
||||
print(f"{desc} max memory: {mem}GB")
|
||||
torch.cuda.empty_cache()
|
||||
return mem
|
||||
|
||||
@@ -17,10 +17,12 @@ if "reduce_scatter_tensor" not in dir(torch.distributed):
|
||||
# Raw operation, does not support autograd, but does support async
|
||||
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
|
||||
world_size = torch.distributed.get_world_size(process_group)
|
||||
output = torch.empty(world_size * input_.shape[0], *input_.shape[1:],
|
||||
dtype=input_.dtype, device=input_.device)
|
||||
handle = torch.distributed.all_gather_into_tensor(output, input_.contiguous(),
|
||||
group=process_group, async_op=async_op)
|
||||
output = torch.empty(
|
||||
world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
handle = torch.distributed.all_gather_into_tensor(
|
||||
output, input_.contiguous(), group=process_group, async_op=async_op
|
||||
)
|
||||
return output, handle
|
||||
|
||||
|
||||
@@ -28,11 +30,12 @@ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool =
|
||||
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
|
||||
world_size = torch.distributed.get_world_size(process_group)
|
||||
assert input_.shape[0] % world_size == 0
|
||||
output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:],
|
||||
dtype=input_.dtype, device=input_.device)
|
||||
handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(),
|
||||
group=process_group,
|
||||
async_op=async_op)
|
||||
output = torch.empty(
|
||||
input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
handle = torch.distributed.reduce_scatter_tensor(
|
||||
output, input_.contiguous(), group=process_group, async_op=async_op
|
||||
)
|
||||
return output, handle
|
||||
|
||||
|
||||
@@ -102,8 +105,9 @@ all_reduce = AllReduceFunc.apply
|
||||
def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
|
||||
# We want to iterate over parameters with _shared_params=True in the same order,
|
||||
# as different ranks might have different number of parameters (e.g., only rank 0 has bias).
|
||||
pamams_shared = {name: p for name, p in model.named_parameters()
|
||||
if getattr(p, '_shared_params', False)}
|
||||
pamams_shared = {
|
||||
name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
|
||||
}
|
||||
for _, p in sorted(pamams_shared.items()):
|
||||
with torch.no_grad():
|
||||
# Broadcast needs src to be global rank, not group rank
|
||||
@@ -116,8 +120,9 @@ def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
|
||||
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
|
||||
# We want to iterate over parameters with _sequence_parallel=True in the same order,
|
||||
# as different ranks might have different number of parameters (e.g., only rank 0 has bias).
|
||||
params_seqparallel = {name: p for name, p in model.named_parameters()
|
||||
if getattr(p, '_sequence_parallel', False)}
|
||||
params_seqparallel = {
|
||||
name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
|
||||
}
|
||||
grads = [p.grad for _, p in sorted(params_seqparallel.items())]
|
||||
if grads:
|
||||
with torch.no_grad():
|
||||
|
||||
+141
-60
@@ -1,18 +1,15 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
|
||||
from typing import Optional, Union, Sequence, Callable
|
||||
import gc
|
||||
import time
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.profiler import profile, record_function, ProfilerActivity
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from torch import Tensor
|
||||
from torch.profiler import ProfilerActivity, profile, record_function
|
||||
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
|
||||
|
||||
|
||||
@@ -20,6 +17,7 @@ from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoder
|
||||
class InferenceParams:
|
||||
"""Inference parameters that are passed to the main model in order
|
||||
to efficienly calculate and store the context during inference."""
|
||||
|
||||
max_sequence_len: int
|
||||
max_batch_size: int
|
||||
sequence_len_offset: int = 0
|
||||
@@ -38,11 +36,13 @@ def modify_logits_for_top_p_filtering(logits, top_p):
|
||||
# First sort and calculate cumulative sum of probabilities.
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
|
||||
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
||||
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
||||
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
||||
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
logits = logits.masked_fill(indices_to_remove, float('-inf'))
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
1, sorted_indices, sorted_indices_to_remove
|
||||
)
|
||||
logits = logits.masked_fill(indices_to_remove, float("-inf"))
|
||||
|
||||
|
||||
def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
|
||||
@@ -54,7 +54,7 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
|
||||
return logits.argmax(dim=-1)
|
||||
else:
|
||||
if top_p > 0.0:
|
||||
assert top_p <= 1.0, 'top-p should be in (0, 1].'
|
||||
assert top_p <= 1.0, "top-p should be in (0, 1]."
|
||||
if top_k > 0:
|
||||
top_k = min(top_k, logits.size(-1)) # Safety check
|
||||
logits_top, indices = torch.topk(logits, top_k, dim=-1)
|
||||
@@ -62,17 +62,31 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
|
||||
modify_logits_for_top_p_filtering(logits_top, top_p)
|
||||
return indices[
|
||||
torch.arange(indices.shape[0], device=indices.device),
|
||||
torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
|
||||
torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
|
||||
]
|
||||
else:
|
||||
logits_top = logits / temperature
|
||||
modify_logits_for_top_p_filtering(logits_top, top_p)
|
||||
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
|
||||
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
|
||||
dim=-1
|
||||
)
|
||||
|
||||
|
||||
def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
|
||||
eos_token_id=None, teacher_outputs=None, vocab_size=None, tensor_parallel=1,
|
||||
fused_ft_kernel=False, cg=False, timing=False):
|
||||
def decode(
|
||||
input_ids,
|
||||
model,
|
||||
max_length,
|
||||
top_k=1,
|
||||
top_p=0.0,
|
||||
temperature=1.0,
|
||||
eos_token_id=None,
|
||||
teacher_outputs=None,
|
||||
vocab_size=None,
|
||||
tensor_parallel=1,
|
||||
fused_ft_kernel=False,
|
||||
cg=False,
|
||||
timing=False,
|
||||
):
|
||||
"""Decoding, either greedy or with top-k or top-p sampling.
|
||||
If top-k = 0, don't limit the number of candidates (pure sampling).
|
||||
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
|
||||
@@ -92,19 +106,24 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
|
||||
teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
|
||||
if cg:
|
||||
assert fused_ft_kernel
|
||||
if not hasattr(model, '_decoding_cache'):
|
||||
if not hasattr(model, "_decoding_cache"):
|
||||
model._decoding_cache = None
|
||||
model._decoding_cache = update_graph_cache(
|
||||
model, model._decoding_cache, batch_size, seqlen_og, max_length,
|
||||
tensor_parallel=tensor_parallel
|
||||
model,
|
||||
model._decoding_cache,
|
||||
batch_size,
|
||||
seqlen_og,
|
||||
max_length,
|
||||
tensor_parallel=tensor_parallel,
|
||||
)
|
||||
inference_params = model._decoding_cache.inference_params
|
||||
inference_params.max_sequence_len = max_length
|
||||
inference_params.max_batch_size = batch_size
|
||||
inference_params.sequence_len_offset = 0
|
||||
else:
|
||||
inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
|
||||
fused_ft_kernel=fused_ft_kernel)
|
||||
inference_params = InferenceParams(
|
||||
max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
|
||||
)
|
||||
scores = []
|
||||
with torch.inference_mode():
|
||||
if timing:
|
||||
@@ -123,18 +142,32 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
|
||||
sequences = [next_token]
|
||||
inference_params.sequence_len_offset = seqlen_og
|
||||
while True:
|
||||
position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
|
||||
dtype=torch.long, device=input_ids.device)
|
||||
position_ids = torch.full(
|
||||
(batch_size, 1),
|
||||
inference_params.sequence_len_offset,
|
||||
dtype=torch.long,
|
||||
device=input_ids.device,
|
||||
)
|
||||
if not cg:
|
||||
logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
|
||||
inference_params=inference_params, last_token_only=True).logits
|
||||
logits = model(
|
||||
rearrange(next_token, "b -> b 1"),
|
||||
position_ids=position_ids,
|
||||
inference_params=inference_params,
|
||||
last_token_only=True,
|
||||
).logits
|
||||
else:
|
||||
logits = model._decoding_cache.run(rearrange(next_token, 'b -> b 1'), position_ids,
|
||||
inference_params.sequence_len_offset)
|
||||
logits = model._decoding_cache.run(
|
||||
rearrange(next_token, "b -> b 1"),
|
||||
position_ids,
|
||||
inference_params.sequence_len_offset,
|
||||
)
|
||||
if vocab_size is not None:
|
||||
logits = logits[..., :vocab_size]
|
||||
scores.append(logits if not cg else logits.clone())
|
||||
if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1:
|
||||
if (
|
||||
teacher_outputs is None
|
||||
or teacher_output_len <= inference_params.sequence_len_offset + 1
|
||||
):
|
||||
next_token = sample(logits, top_k=top_k, temperature=temperature)
|
||||
else:
|
||||
next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
|
||||
@@ -148,30 +181,45 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
|
||||
if tensor_parallel > 1:
|
||||
torch.distributed.barrier()
|
||||
torch.cuda.synchronize()
|
||||
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
|
||||
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
|
||||
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
|
||||
return output_cls(
|
||||
sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
|
||||
scores=tuple(scores)
|
||||
sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1), scores=tuple(scores)
|
||||
)
|
||||
|
||||
|
||||
class GenerationMixin:
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
|
||||
return_dict_in_generate=False, output_scores=False, **kwargs):
|
||||
output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
|
||||
temperature=temperature, **kwargs)
|
||||
def generate(
|
||||
self,
|
||||
input_ids,
|
||||
max_length,
|
||||
top_k=1,
|
||||
top_p=0.0,
|
||||
temperature=1.0,
|
||||
return_dict_in_generate=False,
|
||||
output_scores=False,
|
||||
**kwargs,
|
||||
):
|
||||
output = decode(
|
||||
input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
|
||||
)
|
||||
if not output_scores:
|
||||
output.scores = None
|
||||
return output if return_dict_in_generate else output.sequences
|
||||
|
||||
|
||||
def allocate_inference_cache(max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence],
|
||||
device, dtype=torch.float16):
|
||||
def allocate_inference_cache(
|
||||
max_batch_size,
|
||||
max_seqlen,
|
||||
nheads,
|
||||
headdim,
|
||||
layers: Union[int, Sequence],
|
||||
device,
|
||||
dtype=torch.float16,
|
||||
):
|
||||
assert dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
packsize = 4 if dtype == torch.float32 else 8
|
||||
assert headdim % packsize == 0
|
||||
@@ -179,9 +227,13 @@ def allocate_inference_cache(max_batch_size, max_seqlen, nheads, headdim, layers
|
||||
v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
|
||||
if isinstance(layers, int):
|
||||
layers = range(layers)
|
||||
return {i: (torch.empty(k_cache_shape, device=device, dtype=dtype),
|
||||
torch.empty(v_cache_shape, device=device, dtype=dtype))
|
||||
for i in layers}
|
||||
return {
|
||||
i: (
|
||||
torch.empty(k_cache_shape, device=device, dtype=dtype),
|
||||
torch.empty(v_cache_shape, device=device, dtype=dtype),
|
||||
)
|
||||
for i in layers
|
||||
}
|
||||
|
||||
|
||||
def seqlen_to_seqlen_type(seqlen: int) -> int:
|
||||
@@ -211,49 +263,70 @@ class DecodingCGCache:
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1,
|
||||
dtype=None, n_warmups=2):
|
||||
def update_graph_cache(
|
||||
model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1, dtype=None, n_warmups=2
|
||||
):
|
||||
if cache is None:
|
||||
cache = DecodingCGCache()
|
||||
param_example = next(iter(model.parameters()))
|
||||
device = param_example.device
|
||||
if dtype is None:
|
||||
dtype = param_example.dtype
|
||||
if ((device, dtype) != (cache.device, cache.dtype) or batch_size > cache.max_batch_size
|
||||
or max_seqlen > cache.max_seqlen): # Invalidate the cache
|
||||
if (
|
||||
(device, dtype) != (cache.device, cache.dtype)
|
||||
or batch_size > cache.max_batch_size
|
||||
or max_seqlen > cache.max_seqlen
|
||||
): # Invalidate the cache
|
||||
cache.callables = {}
|
||||
cache.mempool = None
|
||||
cache.inference_params = None
|
||||
gc.collect()
|
||||
cache.device, cache.dtype = device, dtype
|
||||
cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
|
||||
if hasattr(model, 'allocate_inference_cache'):
|
||||
if hasattr(model, "allocate_inference_cache"):
|
||||
inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
|
||||
else:
|
||||
headdim = getattr(model.config, 'head_dim',
|
||||
model.config.hidden_size // model.config.num_attention_heads)
|
||||
headdim = getattr(
|
||||
model.config,
|
||||
"head_dim",
|
||||
model.config.hidden_size // model.config.num_attention_heads,
|
||||
)
|
||||
inf_cache = allocate_inference_cache(
|
||||
batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
|
||||
model.config.num_hidden_layers, device, dtype
|
||||
batch_size,
|
||||
max_seqlen,
|
||||
model.config.num_attention_heads // tensor_parallel,
|
||||
headdim,
|
||||
model.config.num_hidden_layers,
|
||||
device,
|
||||
dtype,
|
||||
)
|
||||
lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
|
||||
cache.inference_params = InferenceParams(
|
||||
max_sequence_len=max_seqlen, max_batch_size=batch_size,
|
||||
sequence_len_offset=seqlen_og, key_value_memory_dict=inf_cache, fused_ft_kernel=True,
|
||||
lengths_per_sample=lengths_per_sample
|
||||
max_sequence_len=max_seqlen,
|
||||
max_batch_size=batch_size,
|
||||
sequence_len_offset=seqlen_og,
|
||||
key_value_memory_dict=inf_cache,
|
||||
fused_ft_kernel=True,
|
||||
lengths_per_sample=lengths_per_sample,
|
||||
)
|
||||
cache.mempool = torch.cuda.graphs.graph_pool_handle()
|
||||
for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
|
||||
if (batch_size, s_type) not in cache.callables:
|
||||
max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
|
||||
cache.callables[batch_size, s_type] = capture_graph(
|
||||
model, cache.inference_params, batch_size, max_seqlen_, mempool=cache.mempool,
|
||||
n_warmups=n_warmups
|
||||
model,
|
||||
cache.inference_params,
|
||||
batch_size,
|
||||
max_seqlen_,
|
||||
mempool=cache.mempool,
|
||||
n_warmups=n_warmups,
|
||||
)
|
||||
|
||||
def dispatch(input_ids, position_ids, seqlen):
|
||||
batch_size = input_ids.shape[0]
|
||||
return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)
|
||||
return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](
|
||||
input_ids, position_ids, seqlen
|
||||
)
|
||||
|
||||
cache.run = dispatch
|
||||
cache.inference_params.sequence_len_offset = 0 # Reset so it's not confusing
|
||||
@@ -275,8 +348,12 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
|
||||
s.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(s):
|
||||
for _ in range(n_warmups):
|
||||
logits = model(input_ids, position_ids=position_ids, inference_params=inference_params,
|
||||
last_token_only=True).logits
|
||||
logits = model(
|
||||
input_ids,
|
||||
position_ids=position_ids,
|
||||
inference_params=inference_params,
|
||||
last_token_only=True,
|
||||
).logits
|
||||
s.synchronize()
|
||||
# This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
|
||||
# which requires that graph launch and non-captured launch to not overlap (I think,
|
||||
@@ -288,8 +365,12 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
|
||||
# To allow capture, automatically sets a side stream as the current stream in the context
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, pool=mempool):
|
||||
logits = model(input_ids, position_ids=position_ids, inference_params=inference_params,
|
||||
last_token_only=True).logits
|
||||
logits = model(
|
||||
input_ids,
|
||||
position_ids=position_ids,
|
||||
inference_params=inference_params,
|
||||
last_token_only=True,
|
||||
).logits
|
||||
|
||||
def run(new_input_ids, new_position_ids, seqlen):
|
||||
inference_params.lengths_per_sample[:] = seqlen
|
||||
|
||||
@@ -3,13 +3,18 @@ from functools import partial
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file as safe_load_file
|
||||
from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
|
||||
from transformers.utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
)
|
||||
from transformers.utils.hub import cached_file, get_checkpoint_shard_files
|
||||
|
||||
|
||||
def state_dict_from_pretrained(model_name, device=None, dtype=None):
|
||||
# If not fp32, then we don't want to load directly to the GPU
|
||||
mapped_device = 'cpu' if dtype not in [torch.float32, None] else device
|
||||
mapped_device = "cpu" if dtype not in [torch.float32, None] else device
|
||||
is_sharded = False
|
||||
load_safe = False
|
||||
resolved_archive_file = None
|
||||
@@ -20,19 +25,23 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
|
||||
safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
|
||||
|
||||
if os.path.isfile(weights_path):
|
||||
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
|
||||
_raise_exceptions_for_missing_entries=False)
|
||||
resolved_archive_file = cached_file(
|
||||
model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
|
||||
)
|
||||
elif os.path.isfile(weights_index_path):
|
||||
resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
|
||||
_raise_exceptions_for_missing_entries=False)
|
||||
resolved_archive_file = cached_file(
|
||||
model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
|
||||
)
|
||||
is_sharded = True
|
||||
elif os.path.isfile(safe_weights_path):
|
||||
resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_NAME,
|
||||
_raise_exceptions_for_missing_entries=False)
|
||||
resolved_archive_file = cached_file(
|
||||
model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
|
||||
)
|
||||
load_safe = True
|
||||
elif os.path.isfile(safe_weights_index_path):
|
||||
resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME,
|
||||
_raise_exceptions_for_missing_entries=False)
|
||||
resolved_archive_file = cached_file(
|
||||
model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
|
||||
)
|
||||
is_sharded = True
|
||||
load_safe = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user