A brief walkthrough of the unsloth Gemma 3 vision fine-tuning code
The code follows unsloth's gemma3-4B fine-tuning example.
It loads a locally downloaded checkpoint with bnb 4-bit quantization, which keeps loading simple and memory-friendly.
# User code
from unsloth import FastVisionModel

model, processor = FastVisionModel.from_pretrained(
    model_name = "/data/……/……/unsloth/gemma-3-4b-it-bnb-4bit",
    load_in_4bit = True,  # 4-bit quantization to reduce memory
)
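As a quick sanity check (a hedged sketch: transformers usually records the bitsandbytes settings on the model config when a 4-bit checkpoint is loaded), you can confirm that quantized weights were actually loaded:

# Hedged sanity check on the model loaded above.
print(getattr(model.config, "quantization_config", None))  # expect load_in_4bit / nf4 details
print(next(model.parameters()).device)                      # confirm which device the weights landed on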
The loading logic inside unsloth's FastVisionModel.from_pretrained:
# unsloth's FastVisionModel.from_pretrained checks whether the model contains a vision module
model_config = AutoConfig.from_pretrained(
    model_name,
    token = token,
    trust_remote_code = trust_remote_code,
)
……
# Check if VLM
is_vlm = any(x.endswith("ForConditionalGeneration") for x in model_config.architectures)
is_vlm = is_vlm or hasattr(model_config, "vision_config")
if auto_model is None:
    auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM

model, tokenizer = FastBaseModel.from_pretrained(
    model_name        = model_name,
    max_seq_length    = max_seq_length,
    dtype             = _get_dtype(dtype),
    load_in_4bit      = load_in_4bit,
    load_in_8bit      = load_in_8bit,
    full_finetuning   = full_finetuning,
    token             = token,
    device_map        = device_map,
    trust_remote_code = trust_remote_code,
    revision          = revision if not is_peft else None,
    model_types       = model_types,
    tokenizer_name    = tokenizer_name,
    auto_model        = auto_model,
    use_gradient_checkpointing = use_gradient_checkpointing,
    supports_sdpa     = supports_sdpa,
    whisper_language  = whisper_language,
    whisper_task      = whisper_task,
    *args, **kwargs,
)
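The VLM check only looks at the architecture names and the presence of a vision_config. A standalone sketch (assuming the unsloth/gemma-3-4b-it-bnb-4bit repo id; this is an illustration, not unsloth's own code path) reproduces the same test:

# Standalone illustration of the same VLM check, outside unsloth.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("unsloth/gemma-3-4b-it-bnb-4bit")  # assumed repo id

is_vlm = any(x.endswith("ForConditionalGeneration") for x in (cfg.architectures or []))
is_vlm = is_vlm or hasattr(cfg, "vision_config")
print(is_vlm)  # for the gemma-3 vision checkpoint this should print True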
Inside the inner FastBaseModel.from_pretrained, the model is identified and loaded, and unsloth decides whether a vision-model processor is needed:
# unsloth FastBaseModel.from_pretrained
# Decide whether this is a VLM and pick the matching processor/tokenizer loading class
is_vlm = (auto_model is AutoModelForVision2Seq)
is_whisper = (whisper_language is not None and whisper_task is not None)
auto_processor = AutoProcessor if (is_vlm or is_whisper) else AutoTokenizer

if (whisper_language and whisper_task) or auto_model.__name__.endswith("ForConditionalGeneration"):
    tokenizer = auto_processor.from_pretrained(
        tokenizer_name,
        padding_side = "right",
        token    = token,
        language = whisper_language,
        task     = whisper_task,
    )
else:
    tokenizer = auto_processor.from_pretrained(
        tokenizer_name,
        padding_side = "right",
        token = token,
    )
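Because gemma-3 is detected as a VLM, the object handed back to the user as `processor` is a full multimodal processor rather than a plain tokenizer. A hedged inspection of what was returned (attribute names follow the usual transformers ProcessorMixin layout):

# Hedged inspection of the returned processor.
print(type(processor).__name__)                # a gemma-3 processor class, not a tokenizer
print(hasattr(processor, "image_processor"))   # True for vision processors
print(type(processor.tokenizer).__name__)      # the wrapped text tokenizer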
The user then configures the model's LoRA parameters:
# User code
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers     = True,  # Turn off for just text!
    finetune_language_layers   = True,  # Should leave on!
    finetune_attention_modules = True,  # Attention good for GRPO
    finetune_mlp_modules       = True,  # Should leave on always!

    r = 16,               # Larger = higher accuracy, but might overfit
    lora_alpha = 16,      # Recommended alpha == r at least
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
    use_rslora = False,   # We support rank stabilized LoRA
    loftq_config = None,  # And LoftQ
    target_modules = "all-linear",  # Optional now! Can specify a list if needed
    modules_to_save = [
        "lm_head",
        "embed_tokens",
    ],
)
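After `get_peft_model` returns, only the injected LoRA weights plus the `modules_to_save` entries (`lm_head`, `embed_tokens`) should be trainable. A quick check, assuming the returned object is a standard PEFT-wrapped model:

# Quick check of trainable vs. total parameters after LoRA injection.
model.print_trainable_parameters()  # PEFT helper, prints trainable / total counts

# Or count manually:
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total     = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")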
Internals of how the LoRA network configuration is built and applied:
# Inside FastVisionModel.get_peft_model, unsloth calls get_peft_regex to pick out
# the module-name tags that should be trainable based on the finetune_* switches
def get_peft_regex(
    model,
    finetune_vision_layers     : bool = True,
    finetune_language_layers   : bool = True,
    finetune_attention_modules : bool = True,
    finetune_mlp_modules       : bool = True,
    target_modules : List[str] = None,
    vision_tags    : List[str] = ["vision", "image", "visual", "patch",],
    language_tags  : List[str] = ["language", "text",],
    attention_tags : List[str] = ["self_attn", "attention", "attn",],
    mlp_tags       : List[str] = ["mlp", "feed_forward", "ffn", "dense",],
) -> str:
    ……
    # Collect the tags of the module groups that were switched on
    regex_model_parts = []
    if finetune_vision_layers:     regex_model_parts += vision_tags
    if finetune_language_layers:   regex_model_parts += language_tags

    regex_components = []
    if finetune_attention_modules: regex_components += attention_tags
    if finetune_mlp_modules:       regex_components += mlp_tags

    regex_model_parts = "|".join(regex_model_parts)
    regex_components  = "|".join(regex_components)
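To make the tag joining concrete, here is a simplified illustration (not unsloth's exact pattern) of how the joined tags can be matched against parameter names to decide which projections receive LoRA:

# Simplified illustration only; unsloth builds a more elaborate regex than this.
import re

regex_model_parts = "|".join(["vision", "image", "visual", "patch", "language", "text"])
regex_components  = "|".join(["self_attn", "attention", "attn", "mlp", "feed_forward", "ffn", "dense"])
pattern = re.compile(rf"(?:{regex_model_parts}).*(?:{regex_components})")

for name in [
    "model.vision_tower.encoder.layers.0.self_attn.q_proj",  # vision + attention -> matched
    "model.language_model.layers.0.mlp.gate_proj",           # language + mlp     -> matched
    "model.language_model.layers.0.input_layernorm",         # no attention/mlp tag -> skipped
]:
    print(name, "->", bool(pattern.search(name)))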
The selected trainable module pattern is then put into the LoRA config:
# unsloth internal code
lora_config_dict = {
    "r"                 : r,
    "lora_alpha"        : lora_alpha,
    "target_modules"    : target_modules,  # the pattern returned by get_peft_regex
    "target_parameters" : kwargs.get("target_parameters", None),
    "lora_dropout"      : lora_dropout,
    "bias"              : bias,
    "task_type"         : task_type,
    "use_rslora"        : use_rslora,
    "init_lora_weights" : init_lora_weights,
    "loftq_config"      : loftq_config,
}
lora_config = LoraConfig(
    **{k: v for k, v in lora_config_dict.items() if k in LoraConfig.__doc__},
)
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing = use_gradient_checkpointing,
)
model = _get_peft_model(model, lora_config)
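Once `_get_peft_model` has wrapped the model, every targeted linear layer carries `lora_A`/`lora_B` submodules. A hedged way to list which modules actually received adapters:

# List modules that were wrapped with LoRA adapters (hedged sketch).
lora_modules = [name for name, module in model.named_modules() if hasattr(module, "lora_A")]
print(len(lora_modules), "modules carry LoRA adapters")
print(lora_modules[:5])  # typically attention q/k/v/o and MLP projections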
Configure the trainable layers: enable gradient updates for the selected modules and freeze the ones that are not needed.
# Function called inside unsloth's prepare_model_for_kbit_training
def prepare_model_for_training(
    model                      : Any,
    use_gradient_checkpointing : Optional = "unsloth",
    use_reentrant              : Optional[bool] = True,
    full_finetuning            : Optional[bool] = False,
    train_layernorms           : Optional[bool] = False,
    train_embedding            : Optional[bool] = False,
    train_lm_head              : Optional[bool] = False,
    float32_mixed_precision    : Optional[bool] = True,
) -> Any:
    ……
    for name, param in model.named_parameters():
        upcast = False
        requires_grad = False
        if not full_finetuning:
            if ".lora_A." in name or ".lora_B." in name or ".lora_magnitude_vector" in name:
                upcast = True
                requires_grad = True
            else:
                requires_grad = False
        else:
            if train_layernorms and ("norm." in name or "_layernorm" in name):
                requires_grad = True
                upcast = True   # Must upcast layernorms to float32
            if train_embedding and ("embed_tokens" in name or "embedding" in name):
                requires_grad = True
                upcast = False  # Can leave in bfloat16
            if train_lm_head and ("lm_head" in name):
                requires_grad = True
                upcast = False  # Can leave in bfloat16
            else:
                requires_grad = True
                upcast = False  # Can leave in bfloat16
            pass

        # Set training or not
        if requires_grad:
            param.requires_grad_(True)
        else:
            param.requires_grad_(False)

        # Upcast to float32 if needed
        if requires_grad:
            name = name.replace("base_model", "model", 1)
            while re.search(r'\.(\d+)\.', name) is not None:
                name = re.sub(r'\.(\d+)\.', r'[\1].', name)
            name = name.replace(".weight", "", 1)
            dtype = torch.float32 if upcast else mixed_precision_dtype
            try:
                # Try original name
                exec(f"{name}.to({str(dtype)})")
            except:
                # Maybe model.model
                exec(f"model.{name}.to({str(dtype)})")
            pass

        if ('norm.' in name or '_layernorm' in name) and os.environ.get("UNSLOTH_UPCAST_LAYERNORM", "0") == "1":
            try:
                name = name.replace("base_model", "model", 1)
                while re.search(r'\.(\d+)\.', name) is not None:
                    name = re.sub(r'\.(\d+)\.', r'[\1].', name)
                name = name.replace(".weight", "", 1)
                # Try original name
                exec(f"{name}.to({str(torch.float32)})")
            except:
                # Maybe model.model
                exec(f"model.{name}.to({str(torch.float32)})")
Load the image-text training dataset from Hugging Face:
# User code
def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [
        processor.apply_chat_template(
            convo, tokenize = False, add_generation_prompt = False,
        ).removeprefix('<bos>')
        for convo in convos
    ]
    return { "text" : texts, }

def convert_to_conversation(sample):
    instruction = "Write the LaTeX representation for this image."
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": instruction},
                {"type": "image", "image": sample["image"]},
            ],
        },
        {"role": "assistant", "content": [{"type": "text", "text": sample["text"]}]},
    ]
    return {"messages": conversation}
……
……
dataset = load_dataset("unsloth/LaTeX_OCR", split = "train")
converted_dataset = [convert_to_conversation(sample) for sample in dataset]

processor = get_chat_template(
    processor,
    "gemma-3",
)
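Before training it is worth eyeballing one converted sample: the `messages` list mixes an image item with text, and `apply_chat_template` renders it into gemma-3 turn markers. A hedged inspection sketch:

# Hedged inspection of one converted training sample.
sample = converted_dataset[0]
print(sample["messages"][0]["content"])   # [{"type": "text", ...}, {"type": "image", ...}]

# Render the conversation text with the gemma-3 template (the image itself is
# handled later by UnslothVisionDataCollator, not by this string).
rendered = processor.apply_chat_template(
    sample["messages"], tokenize = False, add_generation_prompt = False,
)
print(rendered[:300])  # expect <start_of_turn>user ... <end_of_turn> / <start_of_turn>model markers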
get_chat_template looks up the chat template registered under the "gemma-3" key and returns it:
# unsloth internal code
gemma3_ollama = '''
FROM {__FILE_LOCATION__}
TEMPLATE """{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 }}
{{- if or (eq .Role "user") (eq .Role "system") }}<start_of_turn>user
{{ .Content }}<end_of_turn>
{{ if $last }}<start_of_turn>model
{{ end }}
{{- else if eq .Role "assistant" }}<start_of_turn>model
{{ .Content }}{{ if not $last }}<end_of_turn>
{{ end }}
{{- end }}
{{- end }}"""
PARAMETER stop "<end_of_turn>"
PARAMETER stop "<eos>"
PARAMETER temperature 0.1
PARAMETER min_p 0.0
PARAMETER top_k 64
PARAMETER top_p 0.95
PARAMETER num_predict 32768
'''

gemma3_template_eos_token = "<end_of_turn>"
CHAT_TEMPLATES["gemma-3"] = (gemma3_template, gemma3_template_eos_token, False, gemma3_ollama,)
# Inside get_chat_template
def get_chat_template(tokenizer, chat_template, ……):
    chat_template, stop_word, yes_map_eos_token, ollama_modelfile = CHAT_TEMPLATES[chat_template]
    ……  # plus the remaining pad-token handling; mainly fetches the template
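After `get_chat_template` returns, the processor should carry the gemma-3 template string (and the matching `<end_of_turn>` stop word from the tuple above). A small hedged check:

# Hedged check that the gemma-3 template was attached to the processor.
template = getattr(processor, "chat_template", None)
print(template is not None)                 # the template string used by apply_chat_template
print("<end_of_turn>" in (template or ""))  # the turn delimiter registered for "gemma-3"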
Switch the model into training mode and set up the trainer configuration:
# User code
FastVisionModel.for_training(model)  # Enable for training!

trainer = SFTTrainer(
    model            = model,
    train_dataset    = converted_dataset,
    processing_class = processor.tokenizer,
    data_collator    = UnslothVisionDataCollator(model, processor),
    args = SFTConfig(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4,
        gradient_checkpointing = True,
        # use reentrant checkpointing
        gradient_checkpointing_kwargs = {"use_reentrant": False},
        max_grad_norm = 0.3,  # max gradient norm based on the QLoRA paper
        warmup_ratio = 0.03,
        max_steps = 3,
        # num_train_epochs = 2,  # Set this instead of max_steps for full training runs
        learning_rate = 2e-4,
        logging_steps = 1,
        save_strategy = "steps",
        optim = "adamw_torch_fused",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",  # For Weights and Biases
        # You MUST put the below items for vision finetuning:
        remove_unused_columns = False,
        dataset_text_field = "",
        dataset_kwargs = {"skip_prepare_dataset": True},
        max_length = 2048,
    ),
)
trainer_stats = trainer.train()
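`trainer.train()` returns a `TrainOutput`, so `trainer_stats.metrics` plus the CUDA allocator statistics give a quick picture of this short 3-step run. A hedged post-run summary:

# Hedged post-training summary.
import torch

print(trainer_stats.metrics)  # train_runtime, train_loss, etc. from the HF Trainer

if torch.cuda.is_available():
    peak_gb = torch.cuda.max_memory_reserved() / 1024 ** 3
    print(f"peak reserved GPU memory: {peak_gb:.2f} GB")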
Save the trained LoRA adapter:
# User code
# save_pretrained expects a single output directory, e.g. "gemmavision-3" or a path like '/……/……/testlora'
model.save_pretrained("gemmavision-3")      # Local saving of the LoRA adapter
processor.save_pretrained("gemmavision-3")
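To use the saved adapter later, it can be loaded back through `FastVisionModel.from_pretrained` pointing at the adapter directory and switched into inference mode. A hedged sketch following the usual unsloth vision-notebook pattern (the `image` variable here just reuses the first dataset sample; paths and generation settings are illustrative):

# Hedged inference sketch after saving.
from unsloth import FastVisionModel

model, processor = FastVisionModel.from_pretrained(
    model_name = "gemmavision-3",   # the adapter directory saved above
    load_in_4bit = True,
)
FastVisionModel.for_inference(model)  # switch unsloth's patched model to inference mode

image = dataset[0]["image"]           # reuse one LaTeX_OCR image for a smoke test
messages = [{
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text", "text": "Write the LaTeX representation for this image."},
    ],
}]
prompt = processor.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
inputs = processor(image, prompt, add_special_tokens = False, return_tensors = "pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens = 128, use_cache = True)
print(processor.batch_decode(outputs, skip_special_tokens = True)[0])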