目录

unsloth微调gemma3图文代码简析

目录

unsloth微调gemma3图文代码简析

代码基于 unsloth 提供的 Gemma3-4B 微调示例。
加载本地已下载好的模型(bitsandbytes 4-bit 量化版本),节省显存、加载方便。

# 用户部分代码
    model, processor = FastVisionModel.from_pretrained(
        model_name = "/data/……/……/unsloth/gemma-3-4b-it-bnb-4bit",
        load_in_4bit = True,  # 4 bit quantization to reduce memory
    )

unsloth 中 FastVisionModel.from_pretrained 加载模型的逻辑:

        # Excerpt from unsloth FastVisionModel.from_pretrained: read the model's
        # config and decide whether the model contains a vision part (VLM).
        model_config = AutoConfig.from_pretrained(
                model_name,
                token = token,
                trust_remote_code = trust_remote_code,
        )
        ……
        # Check if VLM: a model counts as a VLM if any listed architecture name
        # ends with "ForConditionalGeneration", or if its config has a
        # `vision_config` attribute.
        is_vlm = any(x.endswith("ForConditionalGeneration") for x in model_config.architectures)
        is_vlm = is_vlm or hasattr(model_config, "vision_config")
        # Pick a default Auto* model class only when the caller did not pass one.
        if auto_model is None:
            auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM
        # Delegate the actual loading to FastBaseModel, forwarding the chosen
        # auto_model class together with all loading options.
        model, tokenizer = FastBaseModel.from_pretrained(
            model_name        = model_name,
            max_seq_length    = max_seq_length,
            dtype             = _get_dtype(dtype),
            load_in_4bit      = load_in_4bit,
            load_in_8bit      = load_in_8bit,
            full_finetuning   = full_finetuning,
            token             = token,
            device_map        = device_map,
            trust_remote_code = trust_remote_code,
            # `revision` is dropped (None) when is_peft is true.
            revision          = revision if not is_peft else None,
            model_types       = model_types,
            tokenizer_name    = tokenizer_name,
            auto_model        = auto_model,
            use_gradient_checkpointing = use_gradient_checkpointing,
            supports_sdpa     = supports_sdpa,
            whisper_language  = whisper_language,
            whisper_task      = whisper_task,
            *args, **kwargs,
        )

内层的 FastBaseModel.from_pretrained 负责实际加载模型,并根据 auto_model 判断是否为视觉模型(VLM)或 Whisper:是则用 AutoProcessor 加载处理器,否则用 AutoTokenizer。

        # Excerpt from unsloth FastBaseModel.from_pretrained.
        # Decide whether this is a VLM (or a Whisper model) so the matching
        # loader is used: AutoProcessor for VLM/Whisper, AutoTokenizer otherwise.
        is_vlm = (auto_model is AutoModelForVision2Seq)
        is_whisper = (whisper_language is not None and whisper_task is not None)
        auto_processor = AutoProcessor if (is_vlm or is_whisper) else AutoTokenizer
        # NOTE: this test looks at the class *name*. `AutoModelForVision2Seq.__name__`
        # does not end with "ForConditionalGeneration", so when is_vlm came from
        # auto_model being AutoModelForVision2Seq the `else` branch runs; the
        # processor-vs-tokenizer choice itself comes from `auto_processor` above.
        # Both branches load with padding_side = "right"; only the first one
        # forwards the Whisper language/task.
        if (whisper_language and whisper_task) or auto_model.__name__.endswith("ForConditionalGeneration"):
           tokenizer = auto_processor.from_pretrained(
                tokenizer_name,
                padding_side = "right",
                token        = token,
                language     = whisper_language,
                task         = whisper_task,
            )
        else:
            tokenizer = auto_processor.from_pretrained(
                tokenizer_name,
                padding_side = "right",
                token        = token,
            )

用户配置模型的 LoRA 参数:

# User code: attach LoRA adapters to both the vision and the language layers.
model = FastVisionModel.get_peft_model(
    model,
    # Which parts of the network get adapters.
    finetune_vision_layers=True,      # turn off to fine-tune text only
    finetune_language_layers=True,    # should stay on
    finetune_attention_modules=True,  # attention adapters (good for GRPO)
    finetune_mlp_modules=True,        # should always stay on
    target_modules="all-linear",      # optional; a list of names also works
    # LoRA hyperparameters.
    r=16,               # larger rank: higher accuracy, but may overfit
    lora_alpha=16,      # recommended: alpha == r at least
    lora_dropout=0,
    bias="none",
    use_rslora=False,   # rank-stabilized LoRA is supported but off
    loftq_config=None,  # LoftQ is supported but off
    random_state=3407,
    # Extra modules passed through as PEFT `modules_to_save`.
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],
)

get_peft_model 内部配置 LoRA 网络的逻辑:

# Helper called inside unsloth's FastVisionModel.get_peft_model; it collects the
# module-name fragments that decide which modules get trained.
def get_peft_regex(
    model,
    finetune_vision_layers     : bool = True,
    finetune_language_layers   : bool = True,
    finetune_attention_modules : bool = True,
    finetune_mlp_modules       : bool = True,
    target_modules             : List[str] = None,
    vision_tags                : List[str] = ["vision", "image", "visual", "patch",],
    language_tags              : List[str] = ["language", "text",],
    attention_tags             : List[str] = ["self_attn", "attention", "attn",],
    mlp_tags                   : List[str] = ["mlp", "feed_forward", "ffn", "dense",],
) -> str:
    """Build a module-selection regex from the finetune_* flags (excerpt).

    Only the part that gathers the name fragments is shown. The rest of the
    function (elided here) presumably combines the two alternations below into
    the returned regex string.
    """
    ……
    # Collect name fragments for the enabled model parts (vision / language) and
    # for the enabled module kinds (attention / MLP). `+=` extends the fresh
    # local lists, so the list default arguments are not mutated.
    regex_model_parts = []
    if finetune_vision_layers:     regex_model_parts += vision_tags
    if finetune_language_layers:   regex_model_parts += language_tags
    regex_components  = []
    if finetune_attention_modules: regex_components  += attention_tags
    if finetune_mlp_modules:       regex_components  += mlp_tags

    # Join each group into a regex alternation, e.g. "vision|image|visual|patch".
    regex_model_parts = "|".join(regex_model_parts)
    regex_components  = "|".join(regex_components)

随后,get_peft_regex 返回的目标模块正则作为 target_modules 写入 LoRA 配置:

   # Excerpt from unsloth internals (indentation kept as pasted in the article).
   # Collect every LoRA option into one dict before building LoraConfig.
   lora_config_dict = {
            "r"                 : r,
            "lora_alpha"        : lora_alpha,
            "target_modules"    : target_modules, # value returned by get_peft_regex
            "target_parameters" : kwargs.get("target_parameters", None),
            "lora_dropout"      : lora_dropout,
            "bias"              : bias,
            "task_type"         : task_type,
            "use_rslora"        : use_rslora,
            "init_lora_weights" : init_lora_weights,
            "loftq_config"      : loftq_config,
        }
        # Keep only keys whose names occur somewhere in LoraConfig's docstring,
        # presumably so options unknown to the installed PEFT version are dropped.
        lora_config = LoraConfig(
            **{k:v for k,v in lora_config_dict.items() if k in LoraConfig.__doc__},
        )
        # Prepare parameters for k-bit training (this wraps unsloth's
        # prepare_model_for_training) BEFORE the LoRA adapters are added below.
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing = use_gradient_checkpointing,
        )
        # Wrap the prepared model with the adapters described by lora_config.
        model = _get_peft_model(model, lora_config)

配置各层是否参与训练:对需要训练的参数开启梯度更新(并按需转换精度),其余参数冻结。非全量微调时,只有 LoRA 参数(lora_A / lora_B / lora_magnitude_vector)会开启梯度并上转为 float32:

# Function called inside unsloth's prepare_model_for_kbit_training.
def prepare_model_for_training(
    model                      : Any,
    use_gradient_checkpointing : Optional = "unsloth",
    use_reentrant              : Optional[bool] = True,
    full_finetuning            : Optional[bool] = False,
    train_layernorms           : Optional[bool] = False,
    train_embedding            : Optional[bool] = False,
    train_lm_head              : Optional[bool] = False,
    float32_mixed_precision    : Optional[bool] = True,
) -> Any:
    """Set ``requires_grad`` on every parameter and cast trainable ones (excerpt).

    Parts of the original function are elided (``……``). ``mixed_precision_dtype``
    used below is presumably defined in the elided part (likely from
    ``float32_mixed_precision``) -- confirm upstream.

    Args:
        model: module whose ``named_parameters()`` are walked.
        use_gradient_checkpointing, use_reentrant, float32_mixed_precision:
            not referenced in the visible excerpt.
        full_finetuning: when False, only LoRA parameters (names containing
            ``.lora_A.``, ``.lora_B.`` or ``.lora_magnitude_vector``) are made
            trainable and upcast to float32; every other parameter is frozen.
        train_layernorms, train_embedding, train_lm_head: only consulted when
            ``full_finetuning`` is True (see the NOTE inside the loop).

    Returns:
        Not visible in this excerpt (the return statement is elided).
    """
    ……
    for name, param in model.named_parameters():
        upcast = False
        requires_grad = False
        if not full_finetuning:
            # LoRA-only training: adapter weights train (in float32), all else frozen.
            if ".lora_A." in name or ".lora_B." in name or ".lora_magnitude_vector" in name:
                upcast = True
                requires_grad = True
            else:
                requires_grad = False
        else:
            # NOTE(review): the final `if train_lm_head ... else` pairs only with the
            # lm_head test, so every parameter whose name lacks "lm_head" ends with
            # requires_grad=True and upcast=False. This overrides the layernorm
            # upcast=True set just below; in full fine-tuning every parameter
            # trains and none is upcast, whatever the train_* flags say. Looks
            # unintended (an if/elif chain?) -- confirm against upstream unsloth.
            if train_layernorms and ("norm." in name or "_layernorm" in name):
                requires_grad = True
                upcast = True # Must upcast layernorms to float32
            if train_embedding and ("embed_tokens" in name or "embedding" in name):
                requires_grad = True
                upcast = False # Can leave in bfloat16
            if train_lm_head and ("lm_head" in name):
                requires_grad = True
                upcast = False # Can leave in bfloat16
            else:
                requires_grad = True
                upcast = False # Can leave in bfloat16
        pass
        # Set training or not
        if requires_grad:
            param.requires_grad_(True)
        else:
            param.requires_grad_(False)

        # Upcast to float32 if needed.
        # Turn the parameter name into a Python attribute expression, e.g.
        # "base_model.layers.0.q_proj.weight" -> "model.layers[0].q_proj", then
        # exec "<expr>.to(<dtype>)" (str(torch.float32) renders "torch.float32").
        # If that fails, retry with a "model." prefix. The result of .to() is
        # discarded, so this relies on .to() acting in place (true for
        # nn.Module) -- confirm the expression always resolves to a module.
        # NOTE(review): the bare `except:` also hides unrelated errors; `exec`
        # only sees names derived from the model's own parameters.
        if requires_grad:
            name = name.replace("base_model", "model", 1)
            while re.search(r'\.(\d+)\.', name) is not None:
                name = re.sub(r'\.(\d+)\.', r'[\1].', name)
            name = name.replace(".weight", "", 1)
            dtype = torch.float32 if upcast else mixed_precision_dtype
            try:
                # Try original name
                exec(f"{name}.to({str(dtype)})")
            except:
                # Maybe model.model
                exec(f"model.{name}.to({str(dtype)})")
        pass

        # Optionally upcast norm layers to float32 when the environment variable
        # UNSLOTH_UPCAST_LAYERNORM is "1". `name` may already have been rewritten
        # by the block above (when requires_grad was True).
        if ('norm.' in name or '_layernorm' in name) and os.environ.get("UNSLOTH_UPCAST_LAYERNORM", "0") == "1":
            try:
                name = name.replace("base_model", "model", 1)
                while re.search(r'\.(\d+)\.', name) is not None:
                    name = re.sub(r'\.(\d+)\.', r'[\1].', name)
                name = name.replace(".weight", "", 1)
                # Try original name
                exec(f"{name}.to({str(torch.float32)})")
            except:
                # Maybe model.model
                exec(f"model.{name}.to({str(torch.float32)})")

加载 Hugging Face 上的图文训练数据集(unsloth/LaTeX_OCR),并转换为对话格式

# User code
def formatting_prompts_func(examples):
    """Render every conversation in a batch into one chat-formatted string.

    Conversations are read from the ``"conversations"`` column when present,
    otherwise from ``"messages"`` -- the key that ``convert_to_conversation``
    produces -- so the helper works with either dataset layout.

    Args:
        examples: mapping from column name to a list of conversations
            (assumed to be a batched ``datasets`` map input -- confirm).

    Returns:
        ``{"text": [...]}`` with one rendered string per conversation. A
        leading ``<bos>`` is stripped, presumably because the tokenizer adds
        its own BOS token when the text is tokenized later.
    """
    column = "conversations" if "conversations" in examples else "messages"
    texts = [
        processor.apply_chat_template(
            convo, tokenize=False, add_generation_prompt=False
        ).removeprefix("<bos>")
        for convo in examples[column]
    ]
    return {"text": texts}

def convert_to_conversation(sample):
    """Turn one image/text sample (e.g. from unsloth/LaTeX_OCR) into a chat record.

    The user turn holds a fixed instruction followed by the sample's image. The
    assistant turn holds the sample's ground-truth ``text``.

    Returns:
        ``{"messages": [user_turn, assistant_turn]}``
    """
    prompt = "Write the LaTeX representation for this image."
    user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image", "image": sample["image"]},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": sample["text"]}],
    }
    return {"messages": [user_turn, assistant_turn]}
……
……
# Load the LaTeX OCR training split and convert each sample into the chat
# "messages" format consumed by the trainer's UnslothVisionDataCollator.
# All statements share one indentation level (the pasted original mixed
# column 0 and column 4, which is an IndentationError).
dataset = load_dataset("unsloth/LaTeX_OCR", split="train")
converted_dataset = [convert_to_conversation(sample) for sample in dataset]
# Attach unsloth's registered "gemma-3" chat template to the processor.
processor = get_chat_template(
    processor,
    "gemma-3",
)

get_chat_template 根据 "gemma-3" 这个名称,从 CHAT_TEMPLATES 中取出对应的对话模板:

# unsloth internals: register the "gemma-3" chat template.
# Ollama Modelfile for gemma-3 (presumably used when exporting to Ollama).
# The trailing backslash continues the assignment so the triple-quoted literal
# can start on the next line; without it the bare `gemma3_ollama = ` line is a
# syntax error. The string therefore begins with a newline. {__FILE_LOCATION__}
# is a placeholder, presumably filled with the model file path at export time.
gemma3_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 }}
{{- if or (eq .Role "user") (eq .Role "system") }}<start_of_turn>user
{{ .Content }}<end_of_turn>
{{ if $last }}<start_of_turn>model
{{ end }}
{{- else if eq .Role "assistant" }}<start_of_turn>model
{{ .Content }}{{ if not $last }}<end_of_turn>
{{ end }}
{{- end }}
{{- end }}"""
PARAMETER stop "<end_of_turn>"
PARAMETER stop "<eos>"
PARAMETER temperature 0.1
PARAMETER min_p 0.0
PARAMETER top_k 64
PARAMETER top_p 0.95
PARAMETER num_predict 32768
'''
# Stop word that ends a gemma-3 turn.
gemma3_template_eos_token = "<end_of_turn>"
# Registry entry: (template, stop word, yes_map_eos_token flag, Ollama Modelfile).
# `gemma3_template` is defined elsewhere in unsloth (not shown here).
CHAT_TEMPLATES["gemma-3"] = (gemma3_template, gemma3_template_eos_token, False, gemma3_ollama,)

# Simplified excerpt of unsloth's get_chat_template.
def get_chat_template(tokenizer, chat_template, *args, **kwargs):
    """Fetch a registered chat template by name (simplified excerpt).

    Only the two arguments used by the call ``get_chat_template(processor,
    "gemma-3")`` are named here; any further options of the real function
    are not shown.

    Args:
        tokenizer: tokenizer or processor the template is meant for.
        chat_template: registry key such as ``"gemma-3"``; it is rebound below
            to the template stored under that key.
    """
    # Each CHAT_TEMPLATES entry is (template, stop word, yes_map_eos_token,
    # Ollama Modelfile).
    chat_template, stop_word, yes_map_eos_token, ollama_modelfile = CHAT_TEMPLATES[chat_template]
    ...  # Remaining pad-token handling elided; the key step is fetching the template.

开启训练模式,并配置训练参数

# User code: switch the model to training mode and run supervised fine-tuning.
# All statements share one indentation level (the pasted original mixed column
# 0 and column 4, which is an IndentationError).
FastVisionModel.for_training(model)  # Enable for training!
trainer = SFTTrainer(
    model=model,
    train_dataset=converted_dataset,
    processing_class=processor.tokenizer,
    data_collator=UnslothVisionDataCollator(model, processor),
    args=SFTConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,  # effective batch: 1 * 4 = 4 per device
        gradient_checkpointing=True,
        # Use NON-reentrant checkpointing (use_reentrant=False).
        gradient_checkpointing_kwargs={"use_reentrant": False},
        max_grad_norm=0.3,              # max gradient norm based on QLoRA paper
        warmup_ratio=0.03,
        max_steps=3,
        # num_train_epochs=2,           # Set this instead of max_steps for full training runs
        learning_rate=2e-4,
        logging_steps=1,
        # NOTE: save_steps is left at its transformers default (500 -- confirm
        # for your version), so with max_steps=3 no intermediate checkpoint
        # is written.
        save_strategy="steps",
        optim="adamw_torch_fused",
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        seed=3407,
        output_dir="outputs",
        report_to="none",               # no logging integration; use "wandb" for Weights & Biases
        # You MUST set the items below for vision fine-tuning: keep the raw
        # image/message columns and skip TRL's text-dataset preparation, so the
        # data collator does all preprocessing.
        remove_unused_columns=False,
        dataset_text_field="",
        dataset_kwargs={"skip_prepare_dataset": True},
        max_length=2048,
    ),
)
trainer_stats = trainer.train()

保存训练好的 LoRA 模型(适配器)与处理器

# 用户代码部分
    model.save_pretrained("gemmavision-3",'/……/……/testlora')  # Local saving
    processor.save_pretrained("gemmavision-3")