Complete guide for optimizing JAX and Flax workflows on the Ruqola server’s H200 GPUs.
# CUDA 12.1 compatible JAX (recommended for H200)
pip install "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Additional packages (Flax models, Optax optimizers, Orbax checkpointing)
pip install flax optax orbax-checkpoint
pip install chex ml_collections wandb
# Verify installation — should list the H200 devices, not just CPU
python -c "import jax; print('JAX version:', jax.__version__); print('Devices:', jax.devices())"
# Add to ~/.bashrc or job script
export CUDA_VISIBLE_DEVICES=0                # Use first H200, or 0,1,2 for all
export XLA_PYTHON_CLIENT_PREALLOCATE=false   # Dynamic memory allocation
export XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda
import jax
import jax.numpy as jnp
from jax import devices
print(f"JAX version: {jax.__version__}")
print(f"Available devices: {devices()}")
print(f"Default device: {jax.devices()[0]}")
# Check H200 capabilities
for i, device in enumerate(devices()):
print(f"Device {i}: {device}")
print(f" Device kind: {device.device_kind}")
print(f" Platform: {device.platform}")
# Simple computation test
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
result = jnp.dot(x, y)
print(f"Dot product result: {result} (computed on {result.device()})")
import jax
import jax.numpy as jnp
from jax import device_put, devices
# Move data to specific GPU
def move_to_device(data, device_id=0):
device = devices()[device_id]
return device_put(data, device)
# Basic operations
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([[5.0, 6.0], [7.0, 8.0]])
# Matrix multiplication on GPU
result = jnp.dot(x, y)
print(f"Result device: {result.device()}")
# Element-wise operations
z = x + y
w = jnp.sin(x) * jnp.cos(y)
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random
def init_network_params(layer_sizes, key):
    """Initialize MLP weights and biases.

    Args:
        layer_sizes: Sequence of layer widths, e.g. ``[784, 512, 10]``.
        key: ``jax.random`` PRNG key.

    Returns:
        List of ``(W, b)`` tuples, one per layer, with ``W`` of shape
        ``(fan_in, fan_out)`` and ``b`` of shape ``(fan_out,)``.
    """
    num_layers = len(layer_sizes) - 1
    # One key per weight layer; the original split off an extra, unused key.
    keys = random.split(key, num_layers)
    params = []
    for i in range(num_layers):
        W_key, _ = random.split(keys[i])
        fan_in, fan_out = layer_sizes[i], layer_sizes[i + 1]
        # He (Kaiming) initialization — sqrt(2/fan_in) suits the ReLU hidden
        # layers; the original comment mislabeled this as Xavier.
        W = random.normal(W_key, (fan_in, fan_out)) * jnp.sqrt(2.0 / fan_in)
        b = jnp.zeros(fan_out)
        params.append((W, b))
    return params
def forward_pass(params, x):
    """Run the MLP: ReLU on every hidden layer, linear output layer."""
    *hidden, (W_out, b_out) = params
    activations = x
    for W, b in hidden:
        activations = jnp.maximum(0, jnp.dot(activations, W) + b)
    return jnp.dot(activations, W_out) + b_out
# Initialize model
key = random.PRNGKey(42)
layer_sizes = [784, 512, 256, 10]  # MNIST-sized example
params = init_network_params(layer_sizes, key)

# Test forward pass
# NOTE(review): reusing `key` for data is harmless in a smoke test, but real
# code should draw a fresh split.
x_test = random.normal(key, (32, 784))  # Batch of 32 samples
output = forward_pass(params, x_test)
# FIX: ``.device`` is a property in current JAX; ``.device()`` raises TypeError.
print(f"Output shape: {output.shape}, Device: {output.device}")
import optax

def loss_fn(params, batch_x, batch_y):
    """Mean softmax cross-entropy of the MLP over one batch.

    Args:
        params: List of ``(W, b)`` tuples consumed by ``forward_pass``.
        batch_x: Inputs, shape ``(batch, features)``.
        batch_y: Integer class labels, shape ``(batch,)``.
    """
    logits = forward_pass(params, batch_x)
    return optax.softmax_cross_entropy_with_integer_labels(logits, batch_y).mean()

def accuracy(params, batch_x, batch_y):
    """Fraction of samples whose argmax logit matches the integer label."""
    logits = forward_pass(params, batch_x)
    return jnp.mean(jnp.argmax(logits, -1) == batch_y)
# Compile functions for speed
loss_fn_jit = jit(loss_fn)
accuracy_jit = jit(accuracy)

# Gradient function (differentiates loss_fn w.r.t. its first argument)
grad_fn = jit(grad(loss_fn))

# Optimizer
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

# Training step
@jit
def train_step(params, opt_state, batch_x, batch_y):
    """One update: gradients -> optimizer update -> new params/opt_state."""
    grads = grad_fn(params, batch_x, batch_y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state

# Training loop
for epoch in range(10):
    for batch_x, batch_y in dataloader:  # Your dataloader here
        params, opt_state = train_step(params, opt_state, batch_x, batch_y)
    # Evaluate — NOTE(review): this reuses the *last* training batch, so it is
    # a rough progress signal, not a held-out metric.
    train_loss = loss_fn_jit(params, batch_x, batch_y)
    train_acc = accuracy_jit(params, batch_x, batch_y)
    print(f"Epoch {epoch+1}, Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")
import os
import jax

# Enable dynamic memory allocation (recommended for H200).
# NOTE: must be set before JAX initializes its backend to take effect.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

# Alternative: Set memory fraction
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.9'  # Use 90% of GPU memory

def print_memory_usage():
    """Best-effort dump of GPU allocator statistics.

    Prints a notice when the GPU backend or its stats API is unavailable
    (CPU-only installs, or JAX versions that removed ``xla_bridge``).
    """
    try:
        # Requires jax-cuda extensions; ``jax.lib.xla_bridge`` is deprecated
        # in newer JAX — TODO migrate to jax.devices()[0].memory_stats().
        import jax.lib.xla_bridge as xb
        backend = xb.get_backend('gpu')
        print(f"GPU memory usage: {backend.get_stats()}")
    except Exception:  # FIX: was a bare ``except:`` that also caught SystemExit
        print("Memory stats not available")
from jax.experimental import remat

def create_checkpointed_model(layers):
    """Build a forward pass whose hidden layers are rematerialized.

    ``remat`` (gradient checkpointing) recomputes hidden activations during
    the backward pass instead of storing them, trading compute for memory.

    NOTE(review): the ``layers`` argument is never used — the returned
    function takes ``params`` explicitly; confirm before relying on it.
    """
    @remat
    def checkpointed_layer(x, params):
        # Single ReLU layer; its activations are not kept for backprop.
        W, b = params
        return jnp.maximum(0, jnp.dot(x, W) + b)

    def forward_pass_checkpointed(params, x):
        for layer_params in params[:-1]:
            x = checkpointed_layer(x, layer_params)
        # Final layer without checkpointing (and without activation)
        W, b = params[-1]
        return jnp.dot(x, W) + b

    return forward_pass_checkpointed

# Use checkpointed model for memory efficiency
checkpointed_forward = create_checkpointed_model(params)
import jax.numpy as jnp
from jax import grad, jit
def mixed_precision_forward(params, x):
    """Forward pass running hidden layers in float16, output in float32."""
    # Hidden layers: half precision to exploit Tensor Cores.
    h = x.astype(jnp.float16)
    for weight, bias in params[:-1]:
        h = jnp.maximum(
            0,
            jnp.dot(h, weight.astype(jnp.float16)) + bias.astype(jnp.float16),
        )
    # Output layer: promote back to float32 for numerical stability.
    weight, bias = params[-1]
    h = h.astype(jnp.float32)
    return jnp.dot(h, weight.astype(jnp.float32)) + bias.astype(jnp.float32)
def mixed_precision_loss(params, batch_x, batch_y):
    """Cross-entropy loss computed through the mixed-precision forward pass."""
    logits = mixed_precision_forward(params, batch_x)
    return optax.softmax_cross_entropy_with_integer_labels(logits, batch_y).mean()

# Use mixed precision in training (gradients flow through the f16 layers)
mp_grad_fn = jit(grad(mixed_precision_loss))
def gradient_accumulation_step(params, opt_state, batches, accumulation_steps):
    """Average gradients over several mini-batches, then apply one update.

    Args:
        params: Model parameters (any pytree).
        opt_state: Optimizer state matching ``params``.
        batches: Iterable of ``(batch_x, batch_y)`` mini-batches.
        accumulation_steps: Kept for interface compatibility; the number of
            accumulated batches is actually ``len(batches)``.

    Returns:
        Updated ``(params, opt_state)``.
    """
    # One gradient pytree per mini-batch.
    all_grads = [grad_fn(params, bx, by) for bx, by in batches]
    # Leaf-wise mean across batches.  ``tree_map`` replaces the original
    # hand-rolled recursion and works for any pytree structure.
    averaged_grads = jax.tree_util.tree_map(
        lambda *leaves: jnp.mean(jnp.stack(leaves), axis=0), *all_grads
    )
    updates, opt_state = optimizer.update(averaged_grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state

# Use gradient accumulation for large effective batch sizes
accumulation_steps = 4
mini_batches = [next(dataloader) for _ in range(accumulation_steps)]
params, opt_state = gradient_accumulation_step(params, opt_state, mini_batches, accumulation_steps)
import jax
from jax import jit
from functools import partial

# Aggressive JIT compilation for performance
@partial(jit, static_argnums=(2,))  # num_classes is static shape info for XLA
def optimized_forward_pass(params, x, num_classes):
    """JIT-compiled MLP forward pass returning ``(-1, num_classes)`` logits."""
    # Collapse any leading batch dimensions into one.
    flat = jnp.reshape(x, (-1, x.shape[-1]))
    *hidden, (W_last, b_last) = params
    for W, b in hidden:
        # relu(x @ W + b) — XLA fuses matmul, bias add and relu.
        flat = jax.nn.relu(jnp.dot(flat, W) + b)
    logits = jnp.dot(flat, W_last) + b_last
    return jnp.reshape(logits, (-1, num_classes))
# Compile training step with static arguments
@partial(jit, static_argnums=(3,))
def optimized_train_step(params, opt_state, batch, num_classes):
    """One jitted update step; ``num_classes`` is static for XLA."""
    batch_x, batch_y = batch
    def loss_fn_inner(p):
        logits = optimized_forward_pass(p, batch_x, num_classes)
        return optax.softmax_cross_entropy_with_integer_labels(logits, batch_y).mean()
    # value_and_grad gives loss and gradients in a single backward pass
    loss, grads = jax.value_and_grad(loss_fn_inner)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
import numpy as np
from functools import partial

def create_data_iterator(dataset, batch_size, shuffle=True):
    """Yield ``(batch_x, batch_y)`` device arrays from an indexable dataset."""
    order = np.arange(len(dataset))
    if shuffle:
        np.random.shuffle(order)
    total = len(order)
    for lo in range(0, total, batch_size):
        chunk = order[lo:min(lo + batch_size, total)]
        # Materialize the batch on host, then hand it to JAX (moves to device).
        xs = jnp.array([dataset[j][0] for j in chunk])
        ys = jnp.array([dataset[j][1] for j in chunk])
        yield xs, ys
# Prefetch data to overlap computation and data loading
def prefetch_iterator(iterator, prefetch_size=2):
    """Yield items from *iterator*, read ahead by a background thread.

    A bounded queue of ``prefetch_size`` lets the producer load upcoming
    batches while the consumer computes.

    NOTE: if the consumer abandons this generator early, the producer may
    stay blocked on a full queue; the daemon flag (added here) keeps that
    from hanging interpreter shutdown.
    """
    import threading
    from queue import Queue

    queue = Queue(maxsize=prefetch_size)

    def producer():
        for batch in iterator:
            queue.put(batch)
        queue.put(None)  # sentinel: end of stream

    # FIX: daemonize so an abandoned producer cannot block process exit.
    thread = threading.Thread(target=producer, daemon=True)
    thread.start()
    while True:
        batch = queue.get()
        if batch is None:
            break
        yield batch
    thread.join()
def optimized_attention(query, key, value, mask=None):
    """Scaled dot-product attention.

    Args:
        query, key, value: Arrays of shape ``(..., seq, d_k)``.
        mask: Optional boolean array broadcastable to the score shape;
            positions where ``mask`` is False are suppressed.

    Returns:
        ``(output, attention_weights)``.
    """
    d_k = query.shape[-1]
    # BUG FIX: the original used ``key.transpose(-2, -1)`` (PyTorch-style).
    # In JAX, ``transpose`` takes a full axis permutation — for rank > 2 the
    # two-axis form is an error, and for rank 2 it is a no-op.  ``swapaxes``
    # performs the intended last-two-axes transpose.
    scores = jnp.matmul(query, jnp.swapaxes(key, -2, -1)) / jnp.sqrt(d_k)
    if mask is not None:
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    # Softmax with numerical stability
    attention_weights = jax.nn.softmax(scores, axis=-1)
    # Apply attention to values
    output = jnp.matmul(attention_weights, value)
    return output, attention_weights
# Use optimal tensor shapes (multiples of 8 for H200 Tensor Cores)
def ensure_optimal_shapes(x, target_multiple=8):
    """Zero-pad every dimension of ``x`` up to a multiple of ``target_multiple``."""
    rounded = [
        dim if dim % target_multiple == 0
        else (dim // target_multiple + 1) * target_multiple
        for dim in x.shape
    ]
    if rounded == list(x.shape):
        return x
    padding = [(0, want - have) for want, have in zip(rounded, x.shape)]
    return jnp.pad(x, padding, mode='constant')
import jax
from jax import pmap, devices, device_put_replicated
import jax.numpy as jnp

# Check available devices
num_devices = len(devices())
print(f"Available devices: {num_devices}")

def replicate_params(params):
    """Copy the parameter pytree onto every local device (for ``pmap``)."""
    return device_put_replicated(params, devices())
# Parallel training step
def _parallel_train_step(params, opt_state, batch):
    """Per-device step; gradients are all-reduced before the update.

    BUG FIX: the original applied each device's *local* gradients, so the
    replicas silently diverged.  ``lax.pmean`` over the pmap axis keeps the
    parameters identical across devices (standard data parallelism).
    """
    batch_x, batch_y = batch
    def loss_fn(p):
        logits = forward_pass(p, batch_x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, batch_y).mean()
    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Cross-device mean of gradients (and of the reported loss)
    grads = jax.lax.pmean(grads, axis_name='batch')
    loss = jax.lax.pmean(loss, axis_name='batch')
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# ``axis_name`` makes the pmean collective available inside the step.
parallel_train_step = pmap(_parallel_train_step, axis_name='batch')

# Replicate initial state
replicated_params = replicate_params(params)
replicated_opt_state = device_put_replicated(opt_state, devices())
# Create parallel data batches
def create_parallel_batches(batch, num_devices):
    """Reshape a host batch into a leading device axis for ``pmap``.

    Trailing samples that do not fill every device evenly are dropped.
    """
    batch_x, batch_y = batch
    per_device = batch_x.shape[0] // num_devices
    usable = per_device * num_devices
    shard_x = batch_x[:usable].reshape(num_devices, per_device, *batch_x.shape[1:])
    shard_y = batch_y[:usable].reshape(num_devices, per_device, *batch_y.shape[1:])
    return shard_x, shard_y
# Training loop with pmap
for epoch in range(num_epochs):
    for batch in dataloader:
        # Shard the host batch across devices, then run the replicated step.
        parallel_batch = create_parallel_batches(batch, num_devices)
        replicated_params, replicated_opt_state, losses = parallel_train_step(
            replicated_params, replicated_opt_state, parallel_batch
        )
        # Average loss across devices (``losses`` has one entry per device)
        avg_loss = jnp.mean(losses)
        print(f"Average loss: {avg_loss:.4f}")
import jax
from jax.experimental import mesh_utils, PartitionSpec
from jax.experimental.pjit import pjit

# NOTE(review): this section uses the legacy pjit/PartitionSpec API.  Recent
# JAX moved these to ``jax.sharding`` (Mesh, NamedSharding, PartitionSpec)
# and folded pjit into ``jax.jit`` — confirm against the installed version.

# Create device mesh for model parallelism
devices_array = mesh_utils.create_device_mesh((1, 3))  # 1x3 for 3 H200s
mesh = jax.experimental.Mesh(devices_array, ('data', 'model'))

def create_sharded_params(layer_sizes, key, mesh):
    """Create parameters sharded across the 'model' mesh axis.

    NOTE(review): ``jax.device_put`` with a bare PartitionSpec requires a
    NamedSharding in current JAX — verify this runs on the target version.
    """
    with mesh:
        params = []
        keys = random.split(key, len(layer_sizes))
        for i in range(len(layer_sizes) - 1):
            W_key, b_key = random.split(keys[i])
            # Shard weight matrix along output dimension
            W = random.normal(W_key, (layer_sizes[i], layer_sizes[i + 1]))
            W = jax.device_put(W, PartitionSpec(None, 'model'))
            # NOTE(review): comment in the original said "replicate bias", but
            # PartitionSpec('model') *shards* the bias along the model axis.
            b = jnp.zeros(layer_sizes[i + 1])
            b = jax.device_put(b, PartitionSpec('model'))
            params.append((W, b))
        return params

# Sharded forward pass
@pjit(
    in_shardings=(PartitionSpec(None, 'model'), PartitionSpec('data', None)),
    out_shardings=PartitionSpec('data', None)
)
def sharded_forward_pass(params, x):
    """Forward pass with model parallelism (weights split on 'model' axis)."""
    for W, b in params[:-1]:
        x = jnp.maximum(0, jnp.dot(x, W) + b)
    W, b = params[-1]
    return jnp.dot(x, W) + b

# Use model parallelism
with mesh:
    sharded_params = create_sharded_params(layer_sizes, key, mesh)

# Training step with sharding
@pjit(
    in_shardings=(PartitionSpec(None, 'model'), None, PartitionSpec('data', None), PartitionSpec('data')),
    out_shardings=(PartitionSpec(None, 'model'), None, None)
)
def sharded_train_step(params, opt_state, batch_x, batch_y):
    """Sharded update step built on the sharded forward pass."""
    def loss_fn(p):
        logits = sharded_forward_pass(p, batch_x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, batch_y).mean()
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
# Submit multi-GPU JAX job
# 3 GPUs, 80 GB host RAM, 16 h wall time; the flags inside --command are
# parsed by train_jax_multi.py, not by gpuq.
gpuq submit \
  --command "python train_jax_multi.py --num-devices=3 --model-parallel" \
  --gpus 3 \
  --memory 80 \
  --time 16
import flax.linen as nn
from flax.training import train_state
import optax

class LargeTransformer(nn.Module):
    """Decoder-style Transformer LM: embeddings + N blocks + vocab projection.

    NOTE(review): in ``@nn.compact``, submodule *call order* determines
    parameter names — do not reorder layer construction.
    """
    vocab_size: int = 50000
    d_model: int = 4096
    num_heads: int = 32
    num_layers: int = 24
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, training=False):
        # Token embedding plus fixed sinusoidal positions
        x = nn.Embed(self.vocab_size, self.d_model)(x)
        x = x + self.get_positional_encoding(x.shape[1], self.d_model)
        # Dropout (disabled when not training)
        x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(x)
        # Transformer layers — NOTE(review): despite the original comment
        # claiming "with checkpointing", no remat is applied here.
        for _ in range(self.num_layers):
            x = TransformerBlock(
                d_model=self.d_model,
                num_heads=self.num_heads,
                dropout_rate=self.dropout_rate
            )(x, training=training)
        # Output projection to vocabulary logits
        x = nn.Dense(self.vocab_size)(x)
        return x

    def get_positional_encoding(self, seq_len, d_model):
        """Sinusoidal positional encoding, shape ``(1, seq_len, d_model)``."""
        pos = jnp.arange(seq_len)[:, None]
        i = jnp.arange(d_model)[None, :]
        angle_rates = 1 / jnp.power(10000, (2 * (i // 2)) / d_model)
        angle_rads = pos * angle_rates
        # Apply sin to even indices
        angle_rads = angle_rads.at[:, 0::2].set(jnp.sin(angle_rads[:, 0::2]))
        # Apply cos to odd indices (odd slots still hold the original angles —
        # the update above only touched the even slots)
        angle_rads = angle_rads.at[:, 1::2].set(jnp.cos(angle_rads[:, 1::2]))
        return angle_rads[None, ...]
class TransformerBlock(nn.Module):
    """Post-norm Transformer block: self-attention, then a 4x-wide GELU MLP,
    each followed by a residual add and LayerNorm."""
    d_model: int
    num_heads: int
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, training=False):
        # Multi-head self-attention (query and key/value are both ``x``)
        attn_output = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            dropout_rate=self.dropout_rate,
            deterministic=not training
        )(x, x)
        x = nn.LayerNorm()(x + attn_output)
        # Feed-forward network (4x expansion) with residual connection
        ff_output = nn.Dense(self.d_model * 4)(x)
        ff_output = nn.gelu(ff_output)
        ff_output = nn.Dense(self.d_model)(ff_output)
        ff_output = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(ff_output)
        x = nn.LayerNorm()(x + ff_output)
        return x
# Initialize large model
model = LargeTransformer()
key = jax.random.PRNGKey(42)
dummy_input = jnp.ones((1, 512), dtype=jnp.int32)  # (batch=1, seq=512) token ids

# Initialize parameters (shape inference runs one tracing forward pass)
variables = model.init(key, dummy_input, training=False)
params = variables['params']
print(f"Model parameters: {sum(x.size for x in jax.tree_util.tree_leaves(params))}")
import orbax.checkpoint as ocp
from flax.training import train_state

def create_train_state(model, learning_rate, weight_decay, params=None):
    """Bundle model apply_fn, parameters and AdamW into a TrainState.

    Args:
        model: Flax module providing ``apply``.
        learning_rate: AdamW learning rate.
        weight_decay: Decoupled weight-decay coefficient.
        params: Parameter pytree.  Defaults to the module-level ``params``
            global, preserving the original behavior, which silently
            depended on that global.

    Returns:
        ``flax.training.train_state.TrainState``.
    """
    optimizer = optax.adamw(
        learning_rate=learning_rate,
        weight_decay=weight_decay
    )
    if params is None:
        params = globals()['params']  # legacy fallback to the module global
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer
    )

# Checkpointing manager: save every 1000 steps, keep the 3 most recent.
checkpoint_manager = ocp.CheckpointManager(
    'checkpoints/',
    ocp.PyTreeCheckpointer(),
    options=ocp.CheckpointManagerOptions(
        save_interval_steps=1000,
        max_to_keep=3,
    ),
)
# Training step with gradient clipping
@jit
def train_step_with_clipping(state, batch, dropout_rng):
    """One LM training step with element-wise gradient value clipping.

    NOTE(review): ``jnp.clip(g, -1.0, 1.0)`` clips each gradient *element*
    to [-1, 1]; this is value clipping, not clipping by global norm
    (``optax.clip_by_global_norm``), which the name might suggest.
    """
    def loss_fn(params):
        logits = model.apply(
            {'params': params},
            batch['input_ids'],
            training=True,
            rngs={'dropout': dropout_rng}
        )
        # Next-token prediction: logits at t are scored against labels at t+1
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits[:, :-1], batch['labels'][:, 1:]
        ).mean()
        return loss
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # Element-wise value clipping (see note in the docstring)
    grads = jax.tree_util.tree_map(
        lambda g: jnp.clip(g, -1.0, 1.0), grads
    )
    state = state.apply_gradients(grads=grads)
    return state, loss
# Large model training loop
# NOTE(review): ``rng``, ``num_training_steps`` and ``dataloader`` must be
# provided by the surrounding script — they are not defined here.
state = create_train_state(model, learning_rate=1e-4, weight_decay=0.01)
for step in range(num_training_steps):
    # Fresh dropout key every step
    dropout_rng, rng = jax.random.split(rng)
    batch = next(dataloader)
    state, loss = train_step_with_clipping(state, batch, dropout_rng)
    if step % 100 == 0:
        print(f"Step {step}, Loss: {loss:.4f}")
    # Save checkpoint
    if step % 1000 == 0:
        checkpoint_manager.save(step, state)
import optax
from typing import NamedTuple

class AdaptiveGradientState(NamedTuple):
    """State carried across steps by ``adaptive_gradient_scaling``."""
    step: int                       # number of updates applied so far
    grad_norm_history: jnp.ndarray  # ring buffer of recent global grad norms
    scale_factor: float             # last scale applied to the updates

def adaptive_gradient_scaling(window_size=100, target_norm=1.0):
    """Gradient transformation that rescales updates toward ``target_norm``.

    Keeps a ring buffer of the last ``window_size`` global gradient norms and
    scales updates by ``target_norm / avg_norm``, clipped to [0.1, 10].

    Returns:
        ``optax.GradientTransformation``.
    """
    def init_fn(params):
        return AdaptiveGradientState(
            step=0,
            grad_norm_history=jnp.zeros(window_size),
            scale_factor=1.0
        )

    def update_fn(updates, state, params=None):
        grad_norm = optax.global_norm(updates)
        new_history = state.grad_norm_history.at[state.step % window_size].set(grad_norm)
        # BUG FIX: average only the slots filled so far.  The original took
        # the mean over the whole zero-initialized buffer, which
        # underestimates the norm for the first ``window_size`` steps and
        # over-scales early updates (pegging the clip at 10x).
        filled = jnp.minimum(state.step + 1, window_size)
        avg_grad_norm = jnp.sum(new_history) / filled
        scale_factor = target_norm / (avg_grad_norm + 1e-8)
        scale_factor = jnp.clip(scale_factor, 0.1, 10.0)
        scaled_updates = jax.tree_util.tree_map(lambda x: x * scale_factor, updates)
        new_state = AdaptiveGradientState(
            step=state.step + 1,
            grad_norm_history=new_history,
            scale_factor=scale_factor
        )
        return scaled_updates, new_state

    return optax.GradientTransformation(init_fn, update_fn)

# Use custom optimizer: adaptive scaling feeds into Adam
optimizer = optax.chain(
    adaptive_gradient_scaling(),
    optax.adam(1e-3)
)
import jax
import tensorflow as tf

def create_tf_data_pipeline(file_pattern, batch_size, seq_len):
    """Build a tf.data pipeline over TFRecords for consumption by JAX.

    Args:
        file_pattern: Glob matching TFRecord shards.
        batch_size: Examples per batch (partial batches are dropped so all
            batch shapes are static for XLA).
        seq_len: Fixed token length of the ``input_ids``/``labels`` features.
    """
    def parse_example(example_proto):
        # Both features are fixed-length int64 vectors of ``seq_len`` tokens.
        features = {
            'input_ids': tf.io.FixedLenFeature([seq_len], tf.int64),
            'labels': tf.io.FixedLenFeature([seq_len], tf.int64),
        }
        return tf.io.parse_single_example(example_proto, features)

    # Interleave reads across shards for I/O parallelism
    dataset = tf.data.Dataset.list_files(file_pattern, shuffle=True)
    dataset = dataset.interleave(
        tf.data.TFRecordDataset,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=False
    )
    # Parse, batch, and prefetch (AUTOTUNE picks the parallelism level)
    dataset = dataset.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset
# Convert TF dataset to JAX arrays
def tf_to_jax_iterator(tf_dataset):
    """Yield dicts of JAX arrays from a tf.data dataset of tokenized batches."""
    for tf_batch in tf_dataset:
        yield {
            name: jnp.array(tf_batch[name].numpy())
            for name in ('input_ids', 'labels')
        }
def focal_loss(logits, labels, alpha=0.25, gamma=2.0):
    """Focal loss (Lin et al.) for class-imbalanced classification."""
    ce = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    true_class_prob = jnp.exp(-ce)
    # Down-weight easy examples: the weight approaches 0 as p_t approaches 1.
    weights = alpha * (1 - true_class_prob) ** gamma
    return jnp.mean(weights * ce)
def label_smoothing_loss(logits, labels, smoothing=0.1):
    """Cross-entropy against labels smoothed toward the uniform distribution."""
    num_classes = logits.shape[-1]
    targets = jax.nn.one_hot(labels, num_classes) * (1 - smoothing)
    targets = targets + smoothing / num_classes
    per_example = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
    return per_example.mean()
def contrastive_loss(embeddings1, embeddings2, labels, margin=1.0):
    """Contrastive loss: pull positive pairs together; push negative pairs
    apart until they are at least ``margin`` distant."""
    distance = jnp.linalg.norm(embeddings1 - embeddings2, axis=1)
    pos_term = labels * distance ** 2
    neg_term = (1 - labels) * jnp.maximum(0, margin - distance) ** 2
    return jnp.mean(pos_term + neg_term) / 2
import jax
import jax.profiler

def profile_training_step():
    """Trace 100 training steps with the JAX profiler.

    Writes a trace (viewable in TensorBoard/Perfetto) to /tmp/jax_trace and
    opens a live profiler server mid-run.  Depends on ``dataloader``,
    ``state`` and ``train_step`` from the surrounding script.

    NOTE(review): assigning ``state`` inside this function makes it local,
    so the first ``train_step(state, ...)`` call raises UnboundLocalError —
    pass state in or declare it global before using this as-is.
    """
    # Start profiling
    jax.profiler.start_trace("/tmp/jax_trace")
    for step in range(100):
        batch = next(dataloader)
        state, loss = train_step(state, batch)
        if step == 50:  # Profile middle section
            jax.profiler.start_server(9999)  # Start live profiler server
    # Stop profiling
    jax.profiler.stop_trace()
# Use JAX debugging utilities
def debug_shapes_and_dtypes(pytree, name="PyTree"):
    """Print the shape and dtype of every leaf in ``pytree``."""
    def report(path, leaf):
        dotted = '.'.join(map(str, path))
        print(f"{name}.{dotted}: shape={leaf.shape}, dtype={leaf.dtype}")
    jax.tree_util.tree_map_with_path(report, pytree)
# Check for NaN/Inf values
def check_for_nans(pytree, step_name=""):
    """Warn (and dump shapes) if any leaf of ``pytree`` contains NaN or Inf."""
    def leaf_is_bad(leaf):
        return jnp.any(jnp.isnan(leaf)) | jnp.any(jnp.isinf(leaf))
    flags = jax.tree_util.tree_leaves(jax.tree_util.tree_map(leaf_is_bad, pytree))
    if any(bool(flag) for flag in flags):
        print(f"WARNING: NaN/Inf detected at {step_name}")
        debug_shapes_and_dtypes(pytree, step_name)
import psutil
import threading
import time

def memory_monitor(interval=5):
    """Start a daemon thread that prints RAM/GPU info every ``interval`` s.

    Returns the started ``threading.Thread``.  The loop never exits; the
    daemon flag lets the interpreter shut down regardless.
    """
    def monitor():
        while True:
            # System memory
            memory = psutil.virtual_memory()
            print(f"System RAM: {memory.percent:.1f}% used, {memory.available/1024**3:.1f}GB available")
            # JAX memory (if available)
            try:
                devices = jax.devices()
                for i, device in enumerate(devices):
                    if device.device_kind == 'gpu':
                        # Simplified — JAX exposes limited per-device memory stats
                        print(f"GPU {i}: {device}")
            except Exception as e:
                print(f"GPU memory info not available: {e}")
            time.sleep(interval)

    thread = threading.Thread(target=monitor, daemon=True)
    thread.start()
    return thread

# Start memory monitoring
memory_thread = memory_monitor(interval=10)
#!/usr/bin/env python3
"""
H200-Optimized JAX/Flax Training Script
Usage: gpuq submit --command "python train_jax_h200.py --config config.py" --gpus 1 --memory 100
"""
import jax
import jax.numpy as jnp
from jax import random, grad, jit, devices
import flax.linen as nn
from flax.training import train_state
import optax
import orbax.checkpoint as ocp
import argparse
import importlib.util
import logging
from pathlib import Path
def setup_logging():
    """Log INFO and above to both ``training.log`` and the console."""
    handlers = [
        logging.FileHandler('training.log'),
        logging.StreamHandler(),
    ]
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=handlers,
    )
def load_config(config_path):
    """Execute a Python file and return it as a module-like config object."""
    spec = importlib.util.spec_from_file_location("config", config_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
def create_train_state(model, params, config):
    """Build a TrainState from config: optimizer choice plus optional clipping.

    Args:
        model: Flax module.
        params: Initialized parameter pytree.
        config: Object with ``optimizer_name``, ``learning_rate``,
            ``weight_decay``, and optionally ``gradient_clip``.

    Raises:
        ValueError: For an unrecognized ``config.optimizer_name``.
    """
    if config.optimizer_name == 'adamw':
        optimizer = optax.adamw(
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay
        )
    elif config.optimizer_name == 'adam':
        optimizer = optax.adam(learning_rate=config.learning_rate)
    else:
        raise ValueError(f"Unknown optimizer: {config.optimizer_name}")
    # Prepend global-norm clipping when configured (> 0 enables it)
    if hasattr(config, 'gradient_clip') and config.gradient_clip > 0:
        optimizer = optax.chain(
            optax.clip_by_global_norm(config.gradient_clip),
            optimizer
        )
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer
    )
@jit
def train_step(state, batch, dropout_rng):
    """Compiled training step: loss, grads, parameter update, plus accuracy."""
    def loss_fn(params):
        logits = state.apply_fn(
            {'params': params},
            batch['input'],
            training=True,
            rngs={'dropout': dropout_rng}
        )
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits, batch['target']
        ).mean()
        return loss, logits
    # has_aux=True: loss_fn returns (loss, logits); grads are w.r.t. loss only
    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    # Accuracy reuses the same logits (no extra forward pass)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == batch['target'])
    return state, loss, accuracy
@jit
def eval_step(state, batch):
    """Compiled evaluation step: mean loss and accuracy, no dropout, no update."""
    logits = state.apply_fn({'params': state.params}, batch['input'], training=False)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch['target']).mean()
    accuracy = jnp.mean(jnp.argmax(logits, -1) == batch['target'])
    return loss, accuracy
def evaluate_model(state, eval_dataloader):
    """Average ``eval_step`` loss/accuracy over a validation dataloader.

    Args:
        state: TrainState passed through to ``eval_step``.
        eval_dataloader: Iterable of evaluation batches.

    Returns:
        ``(mean_loss, mean_accuracy)`` over all batches.

    Raises:
        ValueError: If the dataloader yields no batches (the original
            crashed with ZeroDivisionError in that case).
    """
    total_loss = 0.0
    total_accuracy = 0.0
    num_batches = 0
    for batch in eval_dataloader:
        loss, accuracy = eval_step(state, batch)
        total_loss += loss
        total_accuracy += accuracy
        num_batches += 1
    if num_batches == 0:
        raise ValueError("eval_dataloader produced no batches")
    return total_loss / num_batches, total_accuracy / num_batches
def main():
    """Entry point: parse args, build model/state, run the training loop."""
    parser = argparse.ArgumentParser(description='JAX/Flax Training Script')
    parser.add_argument('--config', required=True, help='Configuration file')
    parser.add_argument('--resume', help='Resume from checkpoint')
    args = parser.parse_args()

    setup_logging()
    config = load_config(args.config)

    # Setup JAX
    logging.info(f"JAX devices: {devices()}")

    # Create model
    if config.model_type == 'transformer':
        # NOTE(review): LargeTransformer must be defined/imported elsewhere.
        model = LargeTransformer(**config.model_params)
    else:
        raise ValueError(f"Unknown model type: {config.model_type}")

    # Initialize parameters with a fixed seed for reproducibility
    key = random.PRNGKey(config.seed)
    init_key, train_key = random.split(key)
    dummy_input = jnp.ones((1, config.sequence_length), dtype=jnp.int32)
    variables = model.init(init_key, dummy_input, training=False)
    params = variables['params']

    # Count parameters
    param_count = sum(x.size for x in jax.tree_util.tree_leaves(params))
    logging.info(f"Model parameters: {param_count:,}")

    # Create training state
    state = create_train_state(model, params, config)

    # Setup checkpointing
    checkpoint_manager = ocp.CheckpointManager(
        config.checkpoint_dir,
        ocp.PyTreeCheckpointer(),
        options=ocp.CheckpointManagerOptions(
            save_interval_steps=config.save_interval,
            max_to_keep=3,
        ),
    )

    # Resume from checkpoint if specified
    if args.resume:
        state = checkpoint_manager.restore(args.resume)
        logging.info(f"Resumed from checkpoint: {args.resume}")

    # Create data loaders (implement based on your data)
    # NOTE(review): create_dataloader is not defined in this file.
    train_dataloader = create_dataloader(config, split='train')
    eval_dataloader = create_dataloader(config, split='eval')

    # Training loop
    step = 0
    for epoch in range(config.num_epochs):
        logging.info(f"Starting epoch {epoch+1}/{config.num_epochs}")

        # Training phase
        epoch_loss = 0.0
        epoch_accuracy = 0.0
        num_batches = 0
        for batch in train_dataloader:
            # Split RNG so every step uses a fresh dropout key
            train_key, dropout_key = random.split(train_key)

            # Training step
            state, loss, accuracy = train_step(state, batch, dropout_key)
            epoch_loss += loss
            epoch_accuracy += accuracy
            num_batches += 1
            step += 1

            # Logging
            if step % config.log_interval == 0:
                logging.info(f"Step {step}, Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")

            # Evaluation
            if step % config.eval_interval == 0:
                eval_loss, eval_accuracy = evaluate_model(state, eval_dataloader)
                logging.info(f"Evaluation - Loss: {eval_loss:.4f}, Accuracy: {eval_accuracy:.4f}")

            # Checkpointing
            if step % config.save_interval == 0:
                checkpoint_manager.save(step, state)
                logging.info(f"Checkpoint saved at step {step}")

        # Epoch summary
        avg_loss = epoch_loss / num_batches
        avg_accuracy = epoch_accuracy / num_batches
        logging.info(f"Epoch {epoch+1} completed - Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.4f}")

    # Final checkpoint
    checkpoint_manager.save(step, state)
    logging.info("Training completed!")

if __name__ == '__main__':
    main()
# JAX/Flax Training Configuration
import jax.numpy as jnp

# Model configuration
model_type = 'transformer'
model_params = {
    'vocab_size': 32000,
    'd_model': 1024,       # hidden size
    'num_heads': 16,
    'num_layers': 12,
    'dropout_rate': 0.1,
}

# Training configuration
num_epochs = 10
batch_size = 32
sequence_length = 512
learning_rate = 1e-4
weight_decay = 0.01
gradient_clip = 1.0        # global-norm clip threshold; <= 0 disables clipping

# Optimizer: 'adamw' or 'adam' (see create_train_state)
optimizer_name = 'adamw'

# Data configuration
train_data_path = '/path/to/train_data'
eval_data_path = '/path/to/eval_data'

# Checkpointing / logging cadence (in optimizer steps)
checkpoint_dir = './checkpoints'
save_interval = 1000
log_interval = 100
eval_interval = 500

# Miscellaneous
seed = 42
Next steps: adapt these templates to your dataset, then submit training runs with `gpuq submit` as shown above.