Text2SQL微调

Text2SQL微调

Text2SQL相关微调的代码,我们拆分到了DB-GPT-Hub子项目当中,也可以直接查看源码。

Text2SQL微调主要包含以下流程:

  • 搭建环境
  • 数据处理
  • SFT训练
  • 权重合并
  • 模型预测
  • 效果评估

我们推荐使用 来搭建Text2SQL微调的环境

当前项目支持多种基座模型,可以按需下载。本教程中我们以为基座模型,模型可以从HuggingFace、魔搭 等平台下载。以HuggingFace为例, 下载命令为:

本教程案例数据主要以 Spider 数据集为示例 :

  • 简介:Spider 数据集是一个跨域的复杂 text2sql 数据集,包含了自然语言问句和分布在 200 个独立数据库中的多条 SQL,内容覆盖了 138 个不同的领域。
  • 下载:下载数据集到项目目录, 即位于中。

项目使用的是信息匹配生成法进行数据准备,即结合表信息的 SQL + Repository 生成方式,这种方式结合了数据表信息,能够更好地理解数据表的结构和关系,适用于生成符合需求的 SQL 语句。

项目已经将相关处理代码封装在对应脚本中,可以直接一键运行脚本命令,在
目录中将得到生成的训练集

其中训练集中为 8659 条,评估集为 1034 条。生成的训练集中数据格式形如:


中配置训练的数据文件,json文件中对应的 key 的值默认为
,此值即在后续训练脚本
中参数
需要传入的值, json中的file_name 的值为训练集的文件名字。

数据处理代码逻辑

数据处理的核心代码主要在
中,核心处理 class 是
,核心处理函数是

首先将 Spider 数据中的 table 信息处理成为字典格式,key 和 value 分别是 db_id 和该 db_id 对应的 table、columns 信息处理成所需的格式,例如:

然后将上述文本填充于 config 文件中 INSTRUCTION_PROMPT 的 {} 部分,形成最终的 instruction, INSTRUCTION_PROMPT 如下所示:
最后将训练集和验证集中每一个 db_id 对应的 question 和 query 处理成模型 SFT 训练所需的格式,即上面数据处理代码执行部分所示的数据格式。

:::success 说明: 如果你想自己收集更多数据进行训练,可以利用本项目相关代码参照如上逻辑进行处理。

:::

为了简便起见,本复现教程以 LoRA 微调直接运行作为示例,但项目微调不仅能支持 LoRA 还支持 QLoRA 以及 deepspeed
加速。训练脚本
详细参数如下所示:

:::color1 train_sft.sh 中关键参数与含义介绍:

  • model_name_or_path :所用 LLM 模型的路径。
  • dataset :取值为训练数据集的配置名字,对应在 dbgpt_hub/data/dataset_info.json 中外层 key 值,如 example_text2sql。
  • max_source_length :输入模型的文本长度,本教程的效果参数为 2048,为多次实验与分析后的最佳长度。
  • max_target_length :输出模型的 sql 内容长度,设置为 512。
  • template:项目设置的不同模型微调的 lora 部分,对于 Llama2 系列的模型均设置为 llama2。
  • lora_target :LoRA 微调时的网络参数更改部分。
  • finetuning_type : 微调类型,取值为 [ ptuning、lora、freeze、full ] 等。
  • lora_rank : LoRA 微调中的秩大小。
  • loran_alpha: LoRA 微调中的缩放系数。
  • output_dir :SFT 微调时 Peft 模块输出的路径,默认设置在 dbgpt_hub/output/adapter/路径下 。
  • per_device_train_batch_size :每张 gpu 上训练样本的批次,如果计算资源支持,可以设置为更大,默认为 1。
  • gradient_accumulation_steps :梯度更新的累计steps值。
  • lr_scheduler_type :学习率类型。
  • gpt 教程logging_steps :日志保存的 steps 间隔。
  • save_steps :模型保存的 ckpt 的 steps 大小值。
  • num_train_epochs :训练数据的 epoch 数。
  • learning_rate : 学习率,推荐的学习率为 2e-4。

:::


如果想基于 QLoRA 训练,可以在脚本中增加参数 quantization_bit 表示是否量化,取值为 [ 4 或者 8 ],开启量化。
对于想微调不同的 LLM,不同模型对应的关键参数 lora_target 和 template,可以参照项目的 README.md 中相关内容进行更改。
项目目录下
下的
,此文件路径为关于模型预测结果默认输出的位置(如果没有则需建立)。本教程
中的详细参数如下

其中参数
的值为模型预测的结果文件名,结果在
目录下可以找到。
对于模型在数据集上的效果评估,默认为在 spider 数据集上。运行以下命令:

由于大模型生成的结果具有一定的随机性,和 temperature 等参数密切相关(可以在 中的 GeneratingArguments 中进行调整)。在项目默认情况下,我们多次评估的效果,执行准确率均在 0.789 及以上。部分实验和评测结果我们已经放在项目 中,仅供参考。

DB-GPT-Hub 基于 大模型用 LoRA 在 Spider 的训练集上微调后的权重文件已经放出,目前在 spider 的评估集上实现了约为 0.789 的执行准确率,权重文件 在 HuggingFace 上可以找到。

本文实验环境为基于一台带有 A100(40G) 的显卡服务器,总训练时长 12h 左右。如果你的机器资源不够,可以优先考虑缩小参数 gradient_accumulation_steps 的取值,另外可以考虑用 QLoRA 的方式微调(训练脚本
中增加
),从我们的经验看,QLoRA 在 8 个 epoch 时和 LoRA 微调的结果相差不大。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请联系我们举报,一经查实,本站将立刻删除。

发布者:Ai探索者,转载请注明出处:https://javaforall.net/242025.html原文链接:https://javaforall.net

(0)
上一篇 2026年3月15日 下午11:40
下一篇 2026年3月15日 下午11:40


相关推荐

关注全栈程序员社区公众号