Text2SQL相关微调的代码,我们拆分到了DB-GPT-Hub子项目当中,也可以直接查看源码。
Text2SQL微调主要包含以下流程:
- 搭建环境
- 数据处理
- SFT训练
- 权重合并
- 模型预测
- 效果评估
我们推荐使用 来搭建Text2SQL微调的环境
当前项目支持多种基座模型,可以按需下载。本教程中我们以为基座模型,模型可以从HuggingFace、魔搭 等平台下载。以HuggingFace为例, 下载命令为:
本教程案例数据主要以 Spider 数据集为示例 :
- 简介:Spider 数据集是一个跨域的复杂 text2sql 数据集,包含了自然语言问句和分布在 200 个独立数据库中的多条 SQL,内容覆盖了 138 个不同的领域。
- 下载:下载数据集到项目目录, 即位于中。
项目使用的是信息匹配生成法进行数据准备,即结合表信息的 SQL + Repository 生成方式,这种方式结合了数据表信息,能够更好地理解数据表的结构和关系,适用于生成符合需求的 SQL 语句。
项目已经将相关处理代码封装在对应脚本中,可以直接一键运行脚本命令,在
目录中将得到生成的训练集
和
其中训练集中为 8659 条,评估集为 1034 条。生成的训练集中数据格式形如:
在
中配置训练的数据文件,json文件中对应的 key 的值默认为
,此值即在后续训练脚本
中参数
需要传入的值, json中的file_name 的值为训练集的文件名字。
数据处理代码逻辑
数据处理的核心代码主要在
中,核心处理 class 是
,核心处理函数是
。
首先将 Spider 数据中的 table 信息处理成为字典格式,key 和 value 分别是 db_id 和该 db_id 对应的 table、columns 信息处理成所需的格式,例如:
然后将上述文本填充于 config 文件中 INSTRUCTION_PROMPT 的 {} 部分,形成最终的 instruction, INSTRUCTION_PROMPT 如下所示:
最后将训练集和验证集中每一个 db_id 对应的 question 和 query 处理成模型 SFT 训练所需的格式,即上面数据处理代码执行部分所示的数据格式。
:::success 说明: 如果你想自己收集更多数据进行训练,可以利用本项目相关代码参照如上逻辑进行处理。
:::
为了简便起见,本复现教程以 LoRA 微调直接运行作为示例,但项目微调不仅能支持 LoRA 还支持 QLoRA 以及 deepspeed
加速。训练脚本
详细参数如下所示:
:::color1 train_sft.sh 中关键参数与含义介绍:
- model_name_or_path :所用 LLM 模型的路径。
- dataset :取值为训练数据集的配置名字,对应在 dbgpt_hub/data/dataset_info.json 中外层 key 值,如 example_text2sql。
- max_source_length :输入模型的文本长度,本教程的效果参数为 2048,为多次实验与分析后的最佳长度。
- max_target_length :输出模型的 sql 内容长度,设置为 512。
- template:项目设置的不同模型微调的 lora 部分,对于 Llama2 系列的模型均设置为 llama2。
- lora_target :LoRA 微调时的网络参数更改部分。
- finetuning_type : 微调类型,取值为 [ ptuning、lora、freeze、full ] 等。
- lora_rank : LoRA 微调中的秩大小。
- loran_alpha: LoRA 微调中的缩放系数。
- output_dir :SFT 微调时 Peft 模块输出的路径,默认设置在 dbgpt_hub/output/adapter/路径下 。
- per_device_train_batch_size :每张 gpu 上训练样本的批次,如果计算资源支持,可以设置为更大,默认为 1。
- gradient_accumulation_steps :梯度更新的累计steps值。
- lr_scheduler_type :学习率类型。
- gpt 教程logging_steps :日志保存的 steps 间隔。
- save_steps :模型保存的 ckpt 的 steps 大小值。
- num_train_epochs :训练数据的 epoch 数。
- learning_rate : 学习率,推荐的学习率为 2e-4。
:::
如果想基于 QLoRA 训练,可以在脚本中增加参数 quantization_bit 表示是否量化,取值为 [ 4 或者 8 ],开启量化。
对于想微调不同的 LLM,不同模型对应的关键参数 lora_target 和 template,可以参照项目的 README.md 中相关内容进行更改。
项目目录下
下的
,此文件路径为关于模型预测结果默认输出的位置(如果没有则需建立)。本教程
中的详细参数如下
:
其中参数
的值为模型预测的结果文件名,结果在
目录下可以找到。
对于模型在数据集上的效果评估,默认为在 spider 数据集上。运行以下命令:
由于大模型生成的结果具有一定的随机性,和 temperature 等参数密切相关(可以在 中的 GeneratingArguments 中进行调整)。在项目默认情况下,我们多次评估的效果,执行准确率均在 0.789 及以上。部分实验和评测结果我们已经放在项目 中,仅供参考。
DB-GPT-Hub 基于 大模型用 LoRA 在 Spider 的训练集上微调后的权重文件已经放出,目前在 spider 的评估集上实现了约为 0.789 的执行准确率,权重文件 在 HuggingFace 上可以找到。
本文实验环境为基于一台带有 A100(40G) 的显卡服务器,总训练时长 12h 左右。如果你的机器资源不够,可以优先考虑缩小参数 gradient_accumulation_steps 的取值,另外可以考虑用 QLoRA 的方式微调(训练脚本
中增加
),从我们的经验看,QLoRA 在 8 个 epoch 时和 LoRA 微调的结果相差不大。
发布者:Ai探索者,转载请注明出处:https://javaforall.net/242025.html原文链接:https://javaforall.net
