-
Notifications
You must be signed in to change notification settings - Fork 195
Expand file tree
/
Copy pathdiffusion_gemma_sft.yaml
More file actions
137 lines (122 loc) · 4.1 KB
/
Copy pathdiffusion_gemma_sft.yaml
File metadata and controls
137 lines (122 loc) · 4.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ============================================================================
# DiffusionGemma (26B-A4B) block-diffusion SFT — full fine-tune on GSM8K.
#
# DiffusionGemma is a block-diffusion model: a causal encoder reads the clean
# prompt+response, and a bidirectional decoder denoises a "canvas" (the response
# region), trained with D3PM uniform-random token corruption (no [MASK]). This
# recipe runs single-turn SFT: FSDP2 + expert parallelism (EP=8), mixed precision
# (fp32 master weights + bf16 compute), canvas length 256, frozen router, and
# two-pass self-conditioning.
#
# Requires 8 GPUs (EP=8). Launch with:
#   torchrun --standalone --nproc-per-node=8 \
#     examples/dllm_sft/finetune.py -c examples/dllm_sft/diffusion_gemma_sft.yaml
#
# GSM8K is consumed in OpenAI chat-messages format; generate the JSONL once with:
#   python examples/dllm_sft/prep_gsm8k.py  # writes ./gsm8k_chat_train.jsonl
# ============================================================================
# Recipe entry point for this config (presumably resolved by finetune.py).
recipe: DiffusionGemmaSFTRecipe

# Training-loop schedule.
step_scheduler:
  # With 8 ranks and local_batch_size 1, presumably one micro-batch per rank per step
  # (no gradient accumulation) — confirm against the recipe.
  global_batch_size: 8
  local_batch_size: 1
  log_remote_every_steps: 5
  ckpt_every_steps: 400  # i.e. checkpoints at steps 400 and 800 (= max_steps)
  # Set far above max_steps so validation never triggers (there is no validation split).
  val_every_steps: 100000
  max_steps: 800
  # GSM8K train ~7.5k examples / global batch 8 = ~934 steps per epoch;
  # max_steps caps the run at 800 steps (< 1 epoch).
  num_epochs: 1
# torch.distributed process-group settings.
dist_env:
  backend: nccl
  timeout_minutes: 10

# Global RNG seed for the run.
seed: 42
# Weights & Biases logging (disabled; flip `enable` to log to the project below).
wandb:
  enable: false
  project: diffusiongemma-sft
# DiffusionGemma 26B-A4B (MoE), loaded in fp32 so the sharded parameters act as master weights.
model:
  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
  pretrained_model_name_or_path: google/diffusiongemma-26B-A4B-it
  torch_dtype: float32  # fp32 master weights; compute is bf16 via mp_policy
  # Length of the response canvas the decoder denoises; presumably must match
  # the dllm block_size (both 256) — confirm.
  canvas_length: 256
  self_conditioning: true  # two-pass self-conditioning
  freeze_router: true  # MoE router weights are kept frozen during SFT
  # Kernel / implementation choices for the model's building blocks.
  backend:
    _target_: nemo_automodel.components.models.common.BackendConfig
    attn: sdpa
    linear: torch
    rms_norm: torch_fp32
    experts: torch_mm
    dispatcher: torch
    enable_hf_state_dict_adapter: true
    enable_fsdp_optimizations: true
# Masked cross-entropy; presumably skips label positions masked out by the dataset
# (e.g. prompt / history tokens) — confirm.
loss_fn:
  _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy
# Checkpointing (cadence is set by step_scheduler's ckpt_every_steps).
checkpoint:
  enabled: true
  checkpoint_dir: dllm_checkpoints/diffusion_gemma_sft/
  model_save_format: safetensors
  save_consolidated: false  # no consolidated single-copy export; presumably sharded files only
# Parallelism: FSDP2 with 8-way expert parallelism for the MoE layers.
distributed:
  strategy: fsdp2
  # Real YAML null (not the string "none"); presumably inferred from the world
  # size and the other *_size values — confirm.
  dp_size: null
  tp_size: 1
  cp_size: 1
  pp_size: 1
  ep_size: 8
  sequence_parallel: false
  activation_checkpointing: true
  mp_policy:  # mixed precision: fp32 master + bf16 compute
    param_dtype: bfloat16  # dtype parameters are cast to for forward/backward
    reduce_dtype: float32  # gradient reduction dtype
    output_dtype: float32
    autocast_dtype: bfloat16  # recipe-specific key (not part of torch's MixedPrecisionPolicy)
  offload_policy: null  # no parameter offloading
  moe:
    reshard_after_forward: false
# Diffusion-LM settings for block-diffusion training.
dllm:
  mode: block_diffusion
  block_size: 256  # equal to the model's canvas_length (256)
  vocab_size: 262144  # presumably must equal the tokenizer/model vocabulary — confirm
  eps: 0.001  # presumably the lower bound on the sampled corruption level — confirm
  pad_block_size: 256
  pad_seq_len_divisible: 256  # the dataset seq_length (1024) is a multiple of this
# AdamW over the fp32 master weights.
optimizer:
  _target_: torch.optim.AdamW
  betas: [0.95, 0.99]
  eps: 1.0e-8
  lr: 1.5e-4  # peak learning rate; warmup/decay come from lr_scheduler
  weight_decay: 1.0e-4
# Gradient clipping by total gradient norm.
clip_grad_norm:
  max_norm: 1.0
# Warmup from init_lr to the optimizer's peak lr, then cosine decay down to min_lr.
lr_scheduler:
  lr_warmup_steps: 25
  init_lr: 0.0
  lr_decay_style: cosine
  min_lr: 1.5e-5  # 10% of the peak lr (1.5e-4)
# GSM8K SFT data in OpenAI chat-messages JSONL format.
dataset:
  _target_: nemo_automodel.components.datasets.llm.chat_dataset.ChatDataset
  path_or_dataset_id: gsm8k_chat_train.jsonl  # generate with examples/dllm_sft/prep_gsm8k.py
  split: train
  shuffle_seed: 42
  seq_length: 1024
  truncation: true
  # Presumably keeps labels aligned with inputs (no next-token shift), as the
  # diffusion objective is not autoregressive — confirm.
  unshifted: true
  mask_history: true  # supervise only the final turn (single-turn SFT)
  tokenizer:
    pretrained_model_name_or_path: google/diffusiongemma-26B-A4B-it
# Stateful dataloader, so the data-iteration position can be saved and restored with checkpoints.
dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.default_collater
  shuffle: true