A Colab script for fine-tuning the model (on a T4 with 16 GB) is available on the GitHub homepage; see Finetuning_LLama_2_0_on_Colab_with_1_GPU.ipynb.
The previous post mentioned two ways to improve on a pretrained model: Retrieval-Augmented Generation (RAG) and fine-tuning. This post walks through model fine-tuning.
Fine-tuning: the process of taking a pretrained LLM and training it further on a smaller, task-specific dataset to adapt it to a particular task or improve its performance. Through fine-tuning we adjust the model's weights on our own data, making it better suited to the unique needs of our application.
Hugging Face's open_llm_leaderboard shows that Llama 2 is a high-performing base model with a permissive license that allows commercial use, so this post uses Llama-2 as the fine-tuning example.
Llama-2 Pre-training
Training a LLaMA 2-class pretrained model from scratch requires enormous amounts of data and compute, with an estimated total cost of around one hundred million US dollars, which is out of reach for most companies and individuals. A far easier approach is to fine-tune an open-source pretrained model, which greatly reduces the data and compute requirements and is feasible even for an individual.
Model Quantization
To make inference faster, quantizing the model is a good choice, and quantization-aware fine-tuning can improve the quality of the quantized model. This section introduces model quantization; the next section covers quantization-aware fine-tuning of Llama-2.
Memory and Disk Requirements
Because the model on disk is loaded fully into memory before it runs, the memory needed is the same as the size of the model on disk.
Model | Original model size | 4-bit quantized size |
---|---|---|
7B | 13GB | 3.9GB |
13B | 24GB | 7.8GB |
30B | 60GB | 19.5GB |
65B | 120GB | 38.5GB |
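These numbers follow roughly from the bits per weight: the "7B" model actually has about 6.7B parameters, so 6.7B × 2 bytes ≈ 13 GB in fp16, while the 4-bit formats store roughly 4.5 bits per weight (the extra ~0.5 bit is scaling metadata), giving 6.7B × 4.5 / 8 ≈ 3.9 GB.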
Model quantization here uses the llama.cpp project on GitHub, which provides both quantization and efficient inference. The officially listed features of llama.cpp include:
- Plain C/C++ implementation without dependencies
- Apple silicon first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
- AVX, AVX2 and AVX512 support for x86 architectures
- Mixed F16 / F32 precision
- 2-bit, 3-bit, 4-bit, 5-bit, 6-bit and 8-bit integer quantization support
- CUDA, Metal and OpenCL GPU backend support
Quantization Methods
There are quite a few quantization methods; the naming convention is "q" + the number of quantization bits + a variant suffix. The list below, based on TheBloke's model repositories on Hugging Face, describes the available methods:
- q2_k: uses Q4_K for the attention.wv and feed_forward.w2 tensors, and Q2_K for the others;
- q3_k_l: uses Q5_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, and Q3_K for the others;
- q3_k_m: uses Q4_K for the attention.wv, attention.wo, and feed_forward.w2 tensors, and Q3_K for the others;
- q3_k_s: uses Q3_K for all tensors;
- q4_0: the original 4-bit quantization method;
- q4_1: accuracy between q4_0 and q5_0, but faster inference than the q5 models;
- q4_k_m: uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, and Q4_K for the others;
- q4_k_s: uses Q4_K for all tensors;
- q5_0: higher accuracy, higher resource usage, and slower inference;
- q5_1: potentially even higher accuracy, resource usage, and inference time than q5_0;
- q5_k_m: uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, and Q5_K for the others;
- q5_k_s: uses Q5_K for all tensors;
- q6_k: uses Q8_K for all tensors;
- q8_0: almost indistinguishable from half-precision float16, but with high resource usage and slow inference; not recommended for most users;
The wv and wo above refer to the value-projection and output-projection weight matrices of the attention layers; for the derivation of the Llama-2 model itself, see the earlier post 大語(yǔ)言模型之四-LlaMA-2從模型到應(yīng)用.
Empirically, Q5_K_M offers a good balance between model quality and resource usage; if you can trade a bit more quality for lower resource consumption, consider Q4_K_M. In general, the K_M variants perform somewhat better than the K_S variants. The Q2_K and Q3_* variants sacrifice too much quality and are generally not recommended.
Model | Measure | F16 | Q4_0 | Q4_1 | Q5_0 | Q5_1 | Q8_0 |
---|---|---|---|---|---|---|---|
7B | perplexity | 5.9066 | 6.1565 | 6.0912 | 5.9862 | 5.9481 | 5.9070 |
7B | file size | 13.0G | 3.5G | 3.9G | 4.3G | 4.7G | 6.7G |
7B | ms/tok @ 4th | 127 | 55 | 54 | 76 | 83 | 72 |
7B | ms/tok @ 8th | 122 | 43 | 45 | 52 | 56 | 67 |
7B | bits/weight | 16.0 | 4.5 | 5.0 | 5.5 | 6.0 | 8.5 |
13B | perplexity | 5.2543 | 5.3860 | 5.3608 | 5.2856 | 5.2706 | 5.2548 |
13B | file size | 25.0G | 6.8G | 7.6G | 8.3G | 9.1G | 13G |
13B | ms/tok @ 4th | - | 103 | 105 | 148 | 160 | 131 |
13B | ms/tok @ 8th | - | 73 | 82 | 98 | 105 | 128 |
13B | bits/weight | 16.0 | 4.5 | 5.0 | 5.5 | 6.0 | 8.5 |
Perplexity - Model Quality Evaluation
Perplexity is computed from the model's predicted probability for each token in a test set: the log of these probabilities is averaged, and the negative of that average is exponentiated to give the perplexity. The lower the perplexity, the better the model predicts the test set.
The perplexity numbers in the table above were measured on the wikitext2 test set with a context length of 512. The per-token times were measured on a MacBook M1 Pro with 32 GB RAM using 4 and 8 threads.
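As a rough illustration of the definition above, the sketch below computes perplexity for a single text with a Hugging Face causal LM. It is only a sketch: it assumes the `model` and `tokenizer` objects created later in this post (any transformers causal LM would do); passing labels makes the model return the mean negative log-likelihood per token, and exponentiating it gives the perplexity.
import torch

def compute_perplexity(model, tokenizer, text):
    # Tokenize the text and move it to the model's device
    enc = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # With labels == input_ids the model returns the average cross-entropy,
        # i.e. the mean negative log-likelihood per token
        out = model(**enc, labels=enc["input_ids"])
    return torch.exp(out.loss).item()  # lower is better
The notebook code below then downloads a model, converts it to fp16, and quantizes it with llama.cpp: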
# Variables
MODEL_ID = "mlabonne/EvolCodeLlama-7b"
QUANTIZATION_METHODS = ["q4_k_m"]
# Constants
MODEL_NAME = MODEL_ID.split('/')[-1]
GGML_VERSION = "gguf"
# Install llama.cpp
!git clone https://github.com/ggerganov/llama.cpp
!cd llama.cpp && git pull && make clean && LLAMA_CUBLAS=1 make
!pip install -r llama.cpp/requirements.txt
# Download model
!git lfs install
!git clone https://huggingface.co/{MODEL_ID}
# Convert to fp16
fp16 = f"{MODEL_NAME}/{MODEL_NAME.lower()}.{GGML_VERSION}.fp16.bin"
!python llama.cpp/convert.py {MODEL_NAME} --outtype f16 --outfile {fp16}
# Quantize the model for each method in the QUANTIZATION_METHODS list
for method in QUANTIZATION_METHODS:
    qtype = f"{MODEL_NAME}/{MODEL_NAME.lower()}.{GGML_VERSION}.{method}.bin"
    !./llama.cpp/quantize {fp16} {qtype} {method}
The terminal output is as follows:
Cloning into 'llama.cpp'...
remote: Enumerating objects: 7959, done.
remote: Counting objects: 100% (30/30), done.
remote: Compressing objects: 100% (22/22), done.
remote: Total 7959 (delta 11), reused 19 (delta 8), pack-reused 7929
Receiving objects: 100% (7959/7959), 7.71 MiB | 15.48 MiB/s, done.
Resolving deltas: 100% (5477/5477), done.
Already up to date.
I llama.cpp build info:
I UNAME_S: Linux
I UNAME_P: x86_64
I UNAME_M: x86_64
I CFLAGS: -I. -O3 -std=c11 -fPIC -DNDEBUG -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -pthread -march=native -mtune=native -DGGML_USE_K_QUANTS
I CXXFLAGS: -I. -I./common -O3 -std=c++11 -fPIC -DNDEBUG -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread -march=native -mtune=native -DGGML_USE_K_QUANTS
I LDFLAGS:
I CC: cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
I CXX: g++ (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Git LFS initialized.
Cloning into 'EvolCodeLlama-7b'...
remote: Enumerating objects: 35, done.
remote: Counting objects: 100% (32/32), done.
remote: Compressing objects: 100% (32/32), done.
remote: Total 35 (delta 8), reused 0 (delta 0), pack-reused 3
Unpacking objects: 100% (35/35), 483.46 KiB | 2.78 MiB/s, done.
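Once a quantized file has been produced, it can be tested locally with the main binary built by the make step above. This is only a rough sketch: the prompt is purely illustrative, and `qtype` refers to the last file written by the quantization loop.
# Hypothetical quick test of the quantized model with llama.cpp's main binary
!./llama.cpp/main -m {qtype} -p "Write a Python function that reverses a string." -n 128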
- GGUF
GGUF is a binary file format for storing models for GGML inference, designed so that models can be loaded, saved, and read quickly. Models trained in PyTorch or other frameworks are usually exported to GGUF before being used by GGML for inference. GGUF is the successor to the GGML, GGMF, and GGJT formats.
enum ggml_type {
    GGML_TYPE_F32 = 0,
    GGML_TYPE_F16 = 1,
    GGML_TYPE_Q4_0 = 2,
    GGML_TYPE_Q4_1 = 3,
    // GGML_TYPE_Q4_2 = 4, support has been removed
    // GGML_TYPE_Q4_3 (5) support has been removed
    GGML_TYPE_Q5_0 = 6,
    GGML_TYPE_Q5_1 = 7,
    GGML_TYPE_Q8_0 = 8,
    GGML_TYPE_Q8_1 = 9,
    // k-quantizations
    GGML_TYPE_Q2_K = 10,
    GGML_TYPE_Q3_K = 11,
    GGML_TYPE_Q4_K = 12,
    GGML_TYPE_Q5_K = 13,
    GGML_TYPE_Q6_K = 14,
    GGML_TYPE_Q8_K = 15,
    GGML_TYPE_I8,
    GGML_TYPE_I16,
    GGML_TYPE_I32,
    GGML_TYPE_COUNT,
};
The full details of GGUF are documented at https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
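As a minimal illustration of the format, the sketch below reads only the fixed-size GGUF header fields described in the spec linked above (magic, version, tensor count, metadata key-value count). It assumes GGUF v2 or later, where the two counts are 64-bit, and the file path simply follows the naming pattern used in the quantization step earlier.
import struct

def read_gguf_header(path):
    with open(path, "rb") as f:
        magic = f.read(4)                            # should be b"GGUF"
        version, = struct.unpack("<I", f.read(4))    # format version
        n_tensors, = struct.unpack("<Q", f.read(8))  # number of tensors
        n_kv, = struct.unpack("<Q", f.read(8))       # number of metadata key-value pairs
    return magic, version, n_tensors, n_kv

# Path follows the pattern used above; adjust to the file you actually produced
print(read_gguf_header("EvolCodeLlama-7b/evolcodellama-7b.gguf.q4_k_m.bin"))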
Model Training Workflow
Install the environment -> load the pretrained model -> fine-tune the model -> save the model
Of course, you can also directly use TRL, the fine-tuning library developed by Hugging Face, which is more concise.
Environment Setup
!pip install huggingface_hub
!pip install transformers==4.31.0
!pip install accelerate==0.21.0 peft==0.4.0 bitsandbytes==0.40.2 trl==0.4.7
!pip install sentencepiece
transformers is the common framework for large language models; peft (Parameter-Efficient Fine-Tuning) integrates advanced training techniques such as k-bit quantization, low-rank adaptation, and gradient checkpointing, yielding more efficient and resource-friendly fine-tuning.
trl is Hugging Face's reinforcement-learning library; this post only performs instruction fine-tuning and does not cover the reward model or RLHF training.
bitsandbytes is a lightweight wrapper around custom CUDA functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and quantization functions.
Loading the Model
Load the pretrained model. The AutoTokenizer and AutoModelForCausalLM classes from the transformers library download the weights and create the tokenizer and model instances automatically. The BitsAndBytesConfig class configures the quantization settings, e.g. loading in 4-bit precision and the half-precision compute dtype used during fine-tuning (float16 in the code below).
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# Activate 4-bit precision base model loading
use_4bit = True
# Compute dtype for 4-bit base models
bnb_4bit_compute_dtype = "float16"
# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"
# Load tokenizer and model with QLoRA configuration
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
# Activate nested quantization for 4-bit base models (double quantization)
use_nested_quant = False
bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)
model_name = "meta-llama/Llama-2-7b-chat-hf"
#Load LLaMA tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# needed for llama tokenizer
tokenizer.pad_token = tokenizer.eos_token
####Below is for mlabonne/guanaco-llama2-1k dataset
#tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training
#Load the entire model on the GPU 0
device_map = {"": 0}
#Load base model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map
)
Import prepare_model_for_kbit_training from the peft library and call it to prepare the model for k-bit training. gradient_checkpointing_enable() turns on gradient checkpointing, which reduces memory usage during training.
from peft import prepare_model_for_kbit_training
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
Trainable Parameters
The print_trainable_parameters function prints the number of trainable parameters in the model. Import LoraConfig and get_peft_model from the peft library: LoraConfig configures LoRA (Low-Rank Adaptation), the method that shrinks the set of trainable parameters, and get_peft_model applies LoRA to the model. The printout reports how many of the model's parameters are trainable.
The output shows that with LoRA only a small fraction of the parameters (roughly 0.1% in the sample output below) is updated during fine-tuning, which greatly reduces memory requirements. Different LoRA settings need different amounts of training memory, and they also change the number of trainable parameters.
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )
from peft import LoraConfig, get_peft_model
# LoRA attention dimension (rank); common values are 8 or 64
lora_r = 8
# Alpha parameter for LoRA scaling; common values are 16 or 32
lora_alpha = 32
# Dropout probability for the LoRA layers; common values are 0.05 or 0.1
lora_dropout = 0.1
peft_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=lora_dropout,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, peft_config)
print_trainable_parameters(model)
An example of this function's output is:
trainable params: 32768 || all params: 139493376 || trainable%: 0.02349072116513977
trainable params: 65536 || all params: 139526144 || trainable%: 0.04697040864255519
trainable params: 98304 || all params: 156336128 || trainable%: 0.06287989939216097
trainable params: 131072 || all params: 156368896 || trainable%: 0.08382229673093043
trainable params: 163840 || all params: 240820224 || trainable%: 0.06803415314487873
trainable params: 196608 || all params: 240852992 || trainable%: 0.08162987653481174
trainable params: 229376 || all params: 257662976 || trainable%: 0.08902171493975138
trainable params: 262144 || all params: 257695744 || trainable%: 0.10172616587722923
trainable params: 294912 || all params: 342147072 || trainable%: 0.086194512282718
trainable params: 327680 || all params: 342179840 || trainable%: 0.09576250897773522
trainable params: 360448 || all params: 358989824 || trainable%: 0.10040618867235634
trainable params: 393216 || all params: 359022592 || trainable%: 0.10952402683338658
trainable params: 425984 || all params: 443473920 || trainable%: 0.09605615590652997
trainable params: 458752 || all params: 443506688 || trainable%: 0.10343744805038882
trainable params: 491520 || all params: 460316672 || trainable%: 0.1067786656226086
trainable params: 524288 || all params: 460349440 || trainable%: 0.11388913604413203
trainable params: 557056 || all params: 544800768 || trainable%: 0.10224948875255624
trainable params: 589824 || all params: 544833536 || trainable%: 0.10825765321465088
trainable params: 622592 || all params: 561643520 || trainable%: 0.11085180863477247
trainable params: 655360 || all params: 561676288 || trainable%: 0.11667930692491686
trainable params: 688128 || all params: 646127616 || trainable%: 0.10650032330455289
trainable params: 720896 || all params: 646160384 || trainable%: 0.11156610925871926
trainable params: 753664 || all params: 662970368 || trainable%: 0.11367989225123257
trainable params: 786432 || all params: 663003136 || trainable%: 0.11861663351167015
trainable params: 819200 || all params: 747454464 || trainable%: 0.10959864974463515
trainable params: 851968 || all params: 747487232 || trainable%: 0.11397759901803915
trainable params: 884736 || all params: 764297216 || trainable%: 0.11575810842676156
trainable params: 917504 || all params: 764329984 || trainable%: 0.1200402992433174
trainable params: 950272 || all params: 848781312 || trainable%: 0.11195722461900763
trainable params: 983040 || all params: 848814080 || trainable%: 0.11581334748829802
...
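As a back-of-the-envelope check (assuming the usual Llama-2-7B dimensions of hidden size 4096 and 32 decoder layers, with LoRA applied only to q_proj and v_proj as configured above), the final number of trainable LoRA parameters can be estimated as follows:
# Hypothetical estimate of trainable LoRA parameters for the configuration above
hidden_size = 4096       # Llama-2-7B hidden dimension (assumption)
num_layers = 32          # Llama-2-7B decoder layers (assumption)
num_target_modules = 2   # q_proj and v_proj
r = 8                    # lora_r above

# Each adapted linear layer gains two low-rank matrices: A (r x d) and B (d x r)
params_per_module = 2 * r * hidden_size
total_lora_params = num_layers * num_target_modules * params_per_module
print(total_lora_params)  # 4,194,304 -> roughly 0.06% of the ~6.7B base parameters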
Loading the Training Dataset
from datasets import load_dataset
dataset = load_dataset("Abirate/english_quotes")
dataset = dataset.map(lambda samples: tokenizer(samples["quote"]), batched=True)
Downloading readme: 0%| | 0.00/5.55k [00:00<?, ?B/s]
Downloading data files: 0%| | 0/1 [00:00<?, ?it/s]
Downloading data: 0%| | 0.00/647k [00:00<?, ?B/s]
Extracting data files: 0%| | 0/1 [00:00<?, ?it/s]
Generating train split: 0 examples [00:00, ? examples/s]
Map: 0%| | 0/2508 [00:00<?, ? examples/s]
The load_dataset function from Hugging Face's datasets library loads the "Abirate/english_quotes" dataset, and the "quote" field is then tokenized with the LLaMA tokenizer via map.
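As a quick (hypothetical) check of what the map step produced, each record now carries the original text plus the tokenizer outputs:
sample = dataset["train"][0]
print(sample["quote"])           # the original quote text
print(sample["input_ids"][:10])  # first ten token ids added by the tokenizer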
Defining Training Arguments and Training the Model
Fine-tuning can be implemented with either the transformers library or the trl library. TRL is a fine-tuning library developed by Hugging Face, designed to simplify and streamline the fine-tuning of language models; with its intuitive interface and broad feature set, TRL lets researchers and practitioners fine-tune large language models such as LLaMA-v2-7B easily and efficiently.
By leveraging TRL we can unlock the full potential of these models. It provides a comprehensive set of tools and techniques for a range of NLP tasks, including text classification, named-entity recognition, sentiment analysis, and more, and it lets us fine-tune a custom LLaMA-v2-7B model for specific needs.
This example uses the Trainer class from the transformers library, instantiated with the model, the training dataset, and the training arguments. The training arguments set parameters such as batch size, learning rate, and the optimizer (paged_adamw_8bit). DataCollatorForLanguageModeling collates and batches the tokenized data. Finally, trainer.train() starts the fine-tuning run. A simpler trl-based interface is shown later.
import transformers
################################################################################
# TrainingArguments parameters
################################################################################
# Output directory where the model predictions and checkpoints will be stored
output_dir = "./results"
# Number of training epochs
num_train_epochs = 1
# Enable fp16/bf16 training (set bf16 to True with an A100)
fp16 = False
bf16 = False
# Batch size per GPU for training
per_device_train_batch_size = 4
# Batch size per GPU for evaluation
per_device_eval_batch_size = 4
# Number of update steps to accumulate the gradients for
gradient_accumulation_steps = 1
# Enable gradient checkpointing
gradient_checkpointing = True
# Maximum gradient normal (gradient clipping)
max_grad_norm = 0.3
# Initial learning rate (AdamW optimizer)
learning_rate = 2e-4
# Weight decay to apply to all layers except bias/LayerNorm weights
weight_decay = 0.001
# Optimizer to use, paged_adamw_8bit paged_adamw_32bit etc...
optim = "paged_adamw_8bit"
# Learning rate schedule
lr_scheduler_type = "cosine"
# Number of training steps (overrides num_train_epochs)
max_steps = -1
# Ratio of steps for a linear warmup (from 0 to learning rate)
warmup_ratio = 0.03
# Group sequences into batches with same length
# Saves memory and speeds up training considerably
group_by_length = True
# Save checkpoint every X updates steps
save_steps = 0
# Log every X updates steps
logging_steps = 25
# Fine-tuned model name
new_model = "llama-2-7b-shichaog"
# Set training parameters
training_arguments = transformers.TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    fp16=fp16,
    bf16=bf16,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    report_to="tensorboard"
)
## needed for llama tokenizer
tokenizer.pad_token = tokenizer.eos_token
trainer = transformers.Trainer(
    model=model,
    train_dataset=dataset["train"],
    # args=transformers.TrainingArguments(
    #     per_device_train_batch_size=1,
    #     gradient_accumulation_steps=4,
    #     warmup_steps=2,
    #     max_steps=10,
    #     learning_rate=2e-4,
    #     fp16=True,
    #     logging_steps=1,
    #     output_dir="outputs",
    #     optim="paged_adamw_8bit"
    # ),
    args=training_arguments,
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False # silence the warnings. Please re-enable for inference!
trainer.train()
trainer.model.save_pretrained(new_model)
The right-hand side of the original post's screenshot (not reproduced here) shows GPU memory usage during training.
The same fine-tuning can also be done through the trl library interface, which is a bit simpler and functionally equivalent.
################################################################################
# SFT parameters
################################################################################
from trl import SFTTrainer
# Maximum sequence length to use
max_seq_length = None
# Pack multiple short examples in the same input sequence to increase efficiency
packing = False
# Load the entire model on the GPU 0
device_map = {"": 0}
# Set supervised fine-tuning parameters from trl library
trainer2 = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    peft_config=peft_config,
    dataset_text_field="quote",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=packing,
)
# Train model
trainer2.train()
# Save trained model
trainer2.model.save_pretrained(new_model)
This code has the same meaning and effect as the transformers Trainer version above; SFTTrainer is a wrapper around Trainer, and the parameters carry the same meanings. Since the trl library already supports RLHF methods such as PPO, also supporting SFT makes trl more complete.
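Once the adapter has been saved, it can be re-attached to a freshly loaded base model for inference. The following is only a minimal sketch using the objects defined earlier in this post; the prompt is purely illustrative, and it assumes the single-GPU setup configured above.
from peft import PeftModel

# Reload the quantized base model and attach the saved LoRA adapter
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map
)
ft_model = PeftModel.from_pretrained(base_model, new_model)
ft_model.config.use_cache = True  # re-enable the KV cache for generation

prompt = "Give me a short quote about learning."  # illustrative prompt only
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = ft_model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))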