尝试上手大模型的有监督微调

news/2024/5/7 23:03:35

总览

这个文章留下微调 Gemma-2b-it 模型的记录。以很简单的、只有一句话的微调为例。

本文最终的目标:问模型 “微调测试”,模型回答 “我学会了”。

准备

加载模型和分词器

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it",cache_dir="/mnt/g/models/cache",token=token,
)model = AutoModelForCausalLM.from_pretrained("/mnt/g/models/gemma-1.1-2b-it",device_map="auto",torch_dtype=torch.bfloat16,
)

训练集

对话训练集格式通常是 json,网上下载的后缀一般为 .jsonl,长这样(来源):

{"conversation": [{"role": "human", "message": "中国移动营销行来发展报告alink"}, {"role": "assistant", "message": "王大姐,打字细心一点"}], "category": null, "data_source": "weibo"}
{"conversation": [{"role": "human", "message": "小马也疯狂------地位之争。"}, {"role": "assistant", "message": "于老师不给劝劝架么告诉他们再挣也不是老大"}], "category": null, "data_source": "weibo"}
{"conversation": [{"role": "human", "message": "那些年,我们一起偷看过的电视。「暴走漫画」"}, {"role": "assistant", "message": "真不愧是这么走出来的爹·······"}], "category": null, "data_source": "weibo"}
···

使用 HuggingFace 的 dataset 库可加载这个文件。

from datasets import load_datasetdataset_fromfile = load_dataset("json",data_files=r"./test_input_dataset.jsonl",split="train",
)

本文只用一个对话的微调进行示例,所以用以下方法准备训练集 chat

chat = {"conversation": [[{"role": "user","content": "微调测试",},{"role": "assistant","content": "我学会了",},],]
}
chat = Dataset.from_dict(chat)

TRL 库

借助 HuggingFace 的 TRL 库进行微调。

pip install trl

接下来介绍 TRL 的两个类,以及一个特殊的函数 formatting_func

SFTTrainer

SFT 是 “有监督微调” 的缩写(Supervised Finetuning)。

SFTTrainer 继承于 transformers.Trainer。借助 SFTTrainer,可以封装一个专用于语言模型有监督微调的类。

DataCollatorForCompletionOnlyLM

借助 DataCollatorForCompletionOnlyLM,可以仅对需要生成的 prompt 训练。即,只对模型生成的 token 部分计算 loss。

其他细节不必深究,只需要知道 SFTTrainer 需要一个 data_collator 对象,将语料转换成适合训练的形式。

response_template = "<start_of_turn>model\n"collator = DataCollatorForCompletionOnlyLM(tokenizer=tokenizer,response_template=response_template,
)

可见,实例化这个 collator 需要传入 tokenizerresponse_template

在 Gemma 中,模型的回答都接在 "<start_of_turn>model\n" 之后,所以传入这个 response_template 告诉 collator 从这里开始标记需要训练的部分。

formatting_func

语料需要先转换成某种字符串,再转换成 token,才能输入到模型。

为了将训练语料正确处理成符合预训练模型规则的字符串,SFTTrainer 需要传入一个处理函数。

def formatting_prompts_func(example):output_texts = []for c in example["conversation"]:text = tokenizer.apply_chat_template(c, tokenize=False) + tokenizer.eos_tokenoutput_texts.append(text)return output_texts

这里取了巧,借助 tokenizer 自带的 chat_template 转换。

TrainingArguments

需要向 SFTTrainer 传入优化器、学习率等参数。

不必多言,看示例代码。更多可选参数请查阅 HuggingFace 文档。

from transformers import TrainingArgumentsargs = TrainingArguments(per_device_train_batch_size=8,num_train_epochs=30,learning_rate=2e-5,optim="adamw_8bit",bf16=True,output_dir="/mnt/z/model_test",report_to=["tensorboard"],logging_steps=1,
)

开始训练

做好一切准备后,就能实例化 SFTrainer 开始训练了。

trainer = SFTTrainer(model,tokenizer=tokenizer,train_dataset=chat,max_seq_length=1024,args=args,formatting_func=formatting_prompts_func,data_collator=collator,dataset_kwargs={"add_special_tokens": False},  # 特殊 token 已经在 formatting_func 加过了
)
trainer.train()

LoRA

借助 peft 库,只需要封装一遍 model 就能应用 LoRA。

from peft import LoraConfig, get_peft_modellora_config = LoraConfig(r=512,lora_alpha=512,target_modules=["q_proj","k_proj","v_proj","o_proj",],
)model = get_peft_model(model, lora_config)
# model.print_trainable_parameters()

接下来向 SFTTrainer 传入这个 model 就行。

测试

我用这段代码测试训练效果:

chat = [{"role": "user","content": "微调测试",},
]
prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_length=100)
print(tokenizer.decode(outputs[0]))

可以看到效果很明显。

<bos><start_of_turn>user
微调测试<end_of_turn>
<start_of_turn>model
我学会了<end_of_turn>
<eos>

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hjln.cn/news/25413.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈,一经查实,立即删除!

相关文章

TODO -蓝桥杯2018年A组-付账问题

0.题目 题目描述 几个人一起出去吃饭是常有的事。但在结帐的时候,常常会出现一些争执。 现在有 \(n\) 个人出去吃饭,他们总共消费了 \(S\) 元。其中第 \(i\) 个人带了 \(a_i\) 元。幸运的是,所有人带的钱的总数是足够付账的,但现在问题来了:每个人分别要出多少钱呢? 为了…

小伙伴:我是专升本,能不写在简历里吗?

大家好,我是树哥。 最近我推出了简历辅导服务(详见:500 块就能获得 10 年的行业经验,太赚了!),有一位同学找我做了简历辅导。 在阅读他的简历的时候,我发现他的学历没有写入学时间和毕业时间,感觉不是很直观,于是让他补全一下。小伙伴回复说:我是专升本的,本科只有…

Jetpack Compose 中如何实现全面屏

看问题本质,设置全面屏,是系统窗口的行为,与 View 和 Compose 有什么关系呢? 所以,原理和传统 View 视图是一样的,甚至 Api 都是一模一样的,不熟悉的可以看我之前的文章。传送门: Android 全面屏体验 那为什么还要写这篇文章呢?主要是在 Compose 中写法上的一些区别,…

4.26文件上传学习

文件上传,绕过,验证,检测一、文件上传 概念:(不赘述转web安全文件上传)[[9.6-9.7基础和过滤方式]] 前置知识:(除解析漏洞)后门代码需要以特定格式后缀解析,不能以图片后缀解析; 知识点 1、文件上传-前端验证 直接修改前端js代码,文件上传格式; 2、黑白名单 3、use…

Python-Flask-migrate安装和使用

在开发过程中,需要修改数据库模型,而且还要在修改之后更新数据库。最直接的方式就是删除旧表,但这样会丢失数据。 更好的解决办法是使用数据库迁移框架,它可以追踪数据库模式的变化,然后把变动应用到数据库中。在Flask中可以使用Flask-Migrate扩展,来实现数据迁移。 环境…

MySQL-07.InnoDB数据存储结构

C-07.InnoDB数据存储结构 1.数据库的存储结构:页索引结构给我们提供了高效的索引方式,不过索引信息以及数据记录都是保存在文件上的,确切说是存储在页结构中。另一方面,索引是在存储引擎中实现的,MySQL服务器上的存储引擎负责对表中数据的读取和写入工作。不同存储引擎中存…

以链表为基础实现链式队列——————遍历、入队、出队

以链表为基础实现链式队列 ​ 如果打算以链表作为基础来实现队列的操作,可以避免内存浪费以及避免内存成片移动,只需要确定队头和队尾即可,一般把链表头部作为队头,可以实现头删,把链表尾部作为队尾,可以实现尾插。​ 需要注意的点:遍历队列需要备份地址 出队需要考虑空…

Windows设置开机自启动项

一、常见软件的开机自启设置大部分安装的软件,在设置中都带有“设置开机自启”的选项,直接在软件本身的设置中勾选相应开关项即可。 二、本身无开机自启的软件或一些绿色便携式的软件 (一)实现原理Windows自带了一个启动文件夹,在此文件夹中的软件都会在开机后进行启动操作…