Skip to content

Commit 5aa802b

Browse files
authored
add xformers (lm-sys#1970)
1 parent 70e01d3 commit 5aa802b

2 files changed

Lines changed: 142 additions & 0 deletions

File tree

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
"""
2+
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
3+
"""
4+
5+
import logging
6+
import math
7+
from typing import Optional, Tuple
8+
9+
import torch
10+
import transformers.models.llama.modeling_llama
11+
from torch import nn
12+
13+
try:
14+
import xformers.ops
15+
except ImportError:
16+
logging.error("xformers not found! Please install it before trying to use it.")
17+
18+
19+
def replace_llama_attn_with_xformers_attn():
20+
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
21+
22+
23+
def xformers_forward(
24+
self,
25+
hidden_states: torch.Tensor,
26+
attention_mask: Optional[torch.Tensor] = None,
27+
position_ids: Optional[torch.LongTensor] = None,
28+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
29+
output_attentions: bool = False,
30+
use_cache: bool = False,
31+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
32+
# pylint: disable=duplicate-code
33+
bsz, q_len, _ = hidden_states.size()
34+
35+
query_states = (
36+
self.q_proj(hidden_states)
37+
.view(bsz, q_len, self.num_heads, self.head_dim)
38+
.transpose(1, 2)
39+
)
40+
key_states = (
41+
self.k_proj(hidden_states)
42+
.view(bsz, q_len, self.num_heads, self.head_dim)
43+
.transpose(1, 2)
44+
)
45+
value_states = (
46+
self.v_proj(hidden_states)
47+
.view(bsz, q_len, self.num_heads, self.head_dim)
48+
.transpose(1, 2)
49+
)
50+
51+
kv_seq_len = key_states.shape[-2]
52+
if past_key_value is not None:
53+
kv_seq_len += past_key_value[0].shape[-2]
54+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
55+
(
56+
query_states,
57+
key_states,
58+
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
59+
query_states, key_states, cos, sin, position_ids
60+
)
61+
# [bsz, nh, t, hd]
62+
63+
if past_key_value is not None:
64+
# reuse k, v, self_attention
65+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
66+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
67+
68+
past_key_value = (key_states, value_states) if use_cache else None
69+
70+
# We only apply xformers optimizations if we don't need to output the whole attention matrix
71+
if not output_attentions:
72+
query_states = query_states.transpose(1, 2)
73+
key_states = key_states.transpose(1, 2)
74+
value_states = value_states.transpose(1, 2)
75+
76+
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
77+
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
78+
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
79+
# input and output should be of form (bsz, q_len, num_heads, head_dim)
80+
attn_output = xformers.ops.memory_efficient_attention(
81+
query_states, key_states, value_states, attn_bias=None
82+
)
83+
else:
84+
# input and output should be of form (bsz, q_len, num_heads, head_dim)
85+
attn_output = xformers.ops.memory_efficient_attention(
86+
query_states,
87+
key_states,
88+
value_states,
89+
attn_bias=xformers.ops.LowerTriangularMask(),
90+
)
91+
attn_weights = None
92+
else:
93+
attn_weights = torch.matmul(
94+
query_states, key_states.transpose(2, 3)
95+
) / math.sqrt(self.head_dim)
96+
97+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
98+
raise ValueError(
99+
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
100+
f" {attn_weights.size()}"
101+
)
102+
103+
if attention_mask is not None:
104+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
105+
raise ValueError(
106+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
107+
)
108+
attn_weights = attn_weights + attention_mask
109+
attn_weights = torch.max(
110+
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
111+
)
112+
113+
# upcast attention to fp32
114+
attn_weights = nn.functional.softmax(
115+
attn_weights, dim=-1, dtype=torch.float32
116+
).to(query_states.dtype)
117+
attn_output = torch.matmul(attn_weights, value_states)
118+
119+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
120+
raise ValueError(
121+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
122+
f" {attn_output.size()}"
123+
)
124+
125+
attn_output = attn_output.transpose(1, 2)
126+
127+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
128+
attn_output = self.o_proj(attn_output)
129+
return attn_output, attn_weights, past_key_value

fastchat/train/train_xformers.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Make it more memory efficient by monkey patching the LLaMA model with xformers attention.
2+
3+
# Need to call this before importing transformers.
4+
from fastchat.train.llama_xformers_attn_monkey_patch import (
5+
replace_llama_attn_with_xformers_attn,
6+
)
7+
8+
replace_llama_attn_with_xformers_attn()
9+
10+
from fastchat.train.train import train
11+
12+
if __name__ == "__main__":
13+
train()

0 commit comments

Comments
 (0)