Upgrade jax to 0.4.33 for A3 Mega #730

Draft · wants to merge 1 commit into base: main
49 changes: 13 additions & 36 deletions axlearn/common/flash_attention/gpu_attention.py
@@ -34,12 +34,7 @@
 
 # pytype: disable=import-error # pylint: disable=import-error
 from jax import lax
-from jax._src.cudnn.fused_attention_stablehlo import (
-    MaskType,
-    _dot_product_attention,
-    _normalize_layout,
-    check_cudnn_version,
-)
+from jax._src.cudnn.fused_attention_stablehlo import MaskType, dot_product_attention
 from jax._src.lib import cuda_versions
 from jax.experimental import pallas as pl
 
@@ -730,38 +725,20 @@ def cudnn_dot_product_attention(
 
     if qkv_layout != "BTNH":
         raise NotImplementedError(f"Unsupported qkv_layout: {qkv_layout}")
-    # Check if cuDNN is installed.
-    cudnn_version = check_cudnn_version()
     # Support Ampere and Hopper only for now.
     _check_local_compute_capability((80, 90))
     mask_type = MaskType.NO_MASK if not causal else MaskType.CAUSAL
-    layout = _normalize_layout(qkv_layout)
-
-    has_bias = bias is not None
-    has_mask = mask is not None
-    has_dbias = False
-    variadic_args = (has_bias, has_mask, has_dbias)
-    if bias is None:
-        bias = jnp.zeros(0, dtype=query.dtype)
-    if mask is None:
-        mask = jnp.zeros(0, dtype=query.dtype)
-    q_seqlen = jnp.zeros(0, dtype=query.dtype)
-    kv_seqlen = jnp.zeros(0, dtype=query.dtype)
-    # pylint: disable-next=too-many-function-args
-    output = _dot_product_attention(
-        query,
-        key,
-        value,
-        bias,
-        mask,
-        q_seqlen,
-        kv_seqlen,
-        softmax_scale,
-        seed,
-        dropout_rate,
-        variadic_args,
-        mask_type,
-        layout.value,
-        cudnn_version,
+
+    output = dot_product_attention(
+        query=query,
+        key=key,
+        value=value,
+        bias=bias,
+        mask=mask,
+        scale=softmax_scale,
+        seed=seed,
+        dropout_rate=dropout_rate,
+        mask_type=mask_type,
+        qkv_layout=qkv_layout,
     )
     return output
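
Below is a minimal usage sketch, not part of this PR, showing how the public dot_product_attention entry point imported above can be called with the same keyword arguments the updated cudnn_dot_product_attention wrapper passes. The tensor shapes, dtype, seed, and scale are illustrative assumptions; the keyword names and the import come from the diff. Running it requires a CUDA-enabled jax/jaxlib 0.4.33 install on an Ampere or Hopper GPU, matching the _check_local_compute_capability((80, 90)) guard kept above.

import jax
import jax.numpy as jnp
from jax._src.cudnn.fused_attention_stablehlo import MaskType, dot_product_attention

# Illustrative "BTNH" shapes: (batch, time, num_heads, per_head_dim).
batch, seq_len, num_heads, head_dim = 2, 128, 8, 64
k_q, k_k, k_v = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(k_q, (batch, seq_len, num_heads, head_dim), dtype=jnp.bfloat16)
key = jax.random.normal(k_k, (batch, seq_len, num_heads, head_dim), dtype=jnp.bfloat16)
value = jax.random.normal(k_v, (batch, seq_len, num_heads, head_dim), dtype=jnp.bfloat16)

# Keyword arguments mirror the call added in this PR; passing bias/mask as None
# assumes the public API accepts them directly, which is implied by the removal
# of the old zero-size placeholder workaround for the private API.
out = dot_product_attention(
    query=query,
    key=key,
    value=value,
    bias=None,
    mask=None,
    scale=head_dim ** -0.5,
    seed=0,
    dropout_rate=0.0,
    mask_type=MaskType.CAUSAL,
    qkv_layout="BTNH",
)
print(out.shape)  # Expected: (2, 128, 8, 64)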
8 changes: 4 additions & 4 deletions pyproject.toml
@@ -23,8 +23,8 @@ core = [
     "absl-py==2.1.0",
     "chex==0.1.86", # chex 0.1.86 is required for jax 0.4.25.
     "importlab==0.7", # breaks pytype on 0.8
-    "jax==0.4.30",
-    "jaxlib==0.4.30",
+    "jax==0.4.33",
+    "jaxlib==0.4.33",
     "nltk==3.7", # for text preprocessing
     "optax==0.1.7", # optimizers (0.1.0 has known bugs).
     "portpicker",
@@ -100,7 +100,7 @@ gcp = [
 # Note: Specify -f https://storage.googleapis.com/jax-releases/libtpu_releases.html during install.
 tpu = [
     "axlearn[gcp]",
-    "jax[tpu]==0.4.30", # must be >=0.4.19 for compat with v5p.
+    "jax[tpu]==0.4.33", # must be >=0.4.19 for compat with v5p.
 ]
 # Vertex AI tensorboard.
 vertexai_tensorboard = [
@@ -124,7 +124,7 @@ dataflow = [
 # GPU custom kernel dependency.
 gpu = [
     "triton==2.1.0",
-    "jax[cuda12_pip]==0.4.30",
+    "jax[cuda12_pip]==0.4.33",
 ]
 # Open API inference.
 open_api = [
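
As a quick post-install sanity check, not part of this PR, the snippet below confirms an environment actually picked up the new pins after reinstalling the updated extras (for example axlearn[gpu] or axlearn[tpu]); the expected values assume the jax==0.4.33 / jaxlib==0.4.33 constraints in this diff resolved successfully.

import jax
import jaxlib

# After reinstalling, both should report 0.4.33.
print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)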