LLaVA 微调教程
写完之后发现他好像不是很需要这个东西,所以就先发在自己的博客好了。不投稿首页或者候选区应该本来也就不会有多少流量,所以应该不会干嘛的,大不了后面被说不让放网上以后就删掉这篇,嘻嘻。
LLaVA 是最早出现的 Vision Language Model。本教程将教你微调 llava-v1.5-13b 。与本博客现有的基于xtuner的微调教程不同,这个教程将使用deepspeed以拜托对书生生态的依赖。
配置环境
配置环境的官方教程即项目ReadMe
首先我们下载LLaVA的源代码
1 | git clone https://github.com/haotian-liu/LLaVA.git |
然后配置Python环境。如果是在自己电脑上运行,请不要忘记创建conda虚拟环境
1 | # conda create -n llava python=3.10 -y |
最后是下载模型。你可以使用huggingface-cli
直接下载模型。如果您所在的区域不能直接访问Hugging Face,则需要使用镜像网站下载
1 | # 如果不能访问Hugging Face,可以执行下面这一行设置使用hf-mirror镜像站下载 HF_ENDPOINT=https://hf-mirror.com |
准备训练数据
官方预训练(训练投影层)使用的数据集是 LAION-CC-SBU,视觉微调使用的数据集是llava_v1_5_mix665k.json和其他一些数据集,在项目Readme中写得特别清楚。但是我并不打算在这里进行介绍或者是重新训练个新模型。我们将简单构造一个只有一张图像构成的简易数据集。
自定义训练数据集的格式要求在这里。
首先我们下载图片:
1 | mkdir -p ./playground/data/yuanshen |
然后准备图文对。这里只准备一个:
1 | import json |
数据集图像为:
模型微调
这一步我们使用 deepspeed zero2 进行模型 LoRA 微调。得到的微调模型会被保存在./checkpoints/llava-v1.5-7b-lora
里。
1 | deepspeed llava/train/train_mem.py \ |
模型训练的脚本自带wandb,根据情况选就好。不想用wandb就选3
之后慢慢等待训练完成
如果在这一步遇到错误,请移步Github issue查看有没有人和你碰到过一样的问题。如果核查确认没有可以试着提新issue。
模型微调源码选读
内容较长,点击展开查看
上面的命令使用deepspeed运行训练脚本llava/train/train_mem.py
,而train_mem.py
实际上只调用了llava/train/train.py
里面的train(attn_implementation="flash_attention_2")
。train
函数做的事情如下:
首先使用transformers.HfArgumentParser
类解析命令行参数,该类的作用是将命令行参数解析为dataclass
对象。dataclass
是Python3.7中引入的一个新特性,通过dataclass
可以方便地定义一个类,并且可以自动实现__init__
、__repr__
等方法。
1 | parser = transformers.HfArgumentParser( |
然后通过parser.parse_args_into_dataclasses()方法解析命令行参数,并将解析结果保存到model_args、data_args和training_args三个变量中。
1 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() |
训练精度配置与BitsAndBytesConfig
类
接着配置训练精度:
1 | compute_dtype = (torch.float16 if training_args.fp16 |
关于BitsAndBytesConfig
类,这里给出官方文档给大家翻阅。
1 | This is a wrapper class about all possible attributes and features that you can play with a model that has been |
在模型权重加载完成后,还会设置k比特训练:
1 | if training_args.bits in [4, 8]: |
模型权重加载
之后是对模型权重的加载。既然是微调,那就是在已有模型基础上使用数据对模型进行小学习速度的训练。
加载权重的逻辑很简单:
model_args.vision_tower
不为空,'mpt' not in model_args.model_name_or_path
:权重加载进LlavaLlamaForCausalLM
类。这种情况即正常情况model_args.vision_tower
不为空,'mpt' in model_args.model_name_or_path
:手动设置attn_impl
并将权重加载进LlavaMptForCausalLM
类model_args.vision_tower
为空,模型为llama模型,直接将权重加载进LlamaForCausalLM
类
具体代码不放了。
梯度设置
冻结该冻结的,保留需要的:
1 | if model_args.freeze_backbone: # 冻结 |
LoRA
1 | if training_args.lora_enable: # LoRA |
1 | if training_args.bits in [4, 8]: |
之后进行模型其他配置
模型训练
最后使用trainner训练模型
1 | data_module = make_supervised_data_module(tokenizer=tokenizer, |
合并LoRA权重
完成模型训练以后,我们要将LoRA权重与原始模型权重合并:
1 | python scripts/merge_lora_weights.py --model-path "./checkpoints/llava-v1.5-7b-lora" \ |
这样,就能得到可以直接用于推理的模型了,这个模型现在存储在./checkpoints/llava-v1.5-7b-merged
文件夹下。
而所谓合并模型权重,就是先加载一遍base权重,再加载lora权重,最后再将整个模型的权重重新保存。
合并权重源码解读
内容较长,点击展开查看
1 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig |
模型测试
测试模型的性能,会发现微调起了作用:
1 | from llava.eval.run_llava import eval_model |
模型经过微调后,对于我们的训练数据,能得到与标签一致的运行结果:
经过微调的模型输出:
1 | The person in the picture is Nathida, who is a character in the Original God and its derivative works produced by Mihoyo. Her real name is Buyel, the grass god in the "Earthly Seven rulers", and is given the nickname of "Little Lucky Grass King" by the XuMi people, the youngest of the seven gods today. |
而如果不经过微调,模型只会告诉你照片上有个小女孩。