wandb 介紹
-
【wandb官網(wǎng)】
wandb
是Weights & Biases
的縮寫(xiě)(w and b) - 核心作用:
- 可視化重要參數(shù)
- 云端存儲(chǔ)
- 提供各種工具
- 可以和其他工具配合使用,比如下面的
pytorch, HF transformers, tensorflow, keras
等等
- 可以在里面使用
matplotlib
- 貌似是
tensorboard
的上位替代
注冊(cè)賬號(hào)
- 首先我們需要去官網(wǎng)注冊(cè)賬號(hào),貌似不能使用vpn
注冊(cè)號(hào)后,按照教程創(chuàng)建一個(gè)團(tuán)隊(duì),然后來(lái)到這個(gè)界面
可以按照這個(gè)Quickstart
的樣例走一下。選擇Track Runs
,接下來(lái)可以選擇使用哪個(gè)工具訓(xùn)練的模型
然后需要pip install wandb
導(dǎo)包,以及wandb login
登錄
使用 HF Trainer + wandb 訓(xùn)練
- 我們調(diào)用官方給的樣例
我們發(fā)現(xiàn)其實(shí)新添了這幾個(gè)內(nèi)容:WANDB_PROJECT
環(huán)境變量:項(xiàng)目名WANDB_LOG_MODEL
環(huán)境變量:是否保存中繼到wandbWANDB_WATCH
環(huán)境變量 - 在
TrainingArguments
中,設(shè)置了report_to="wandb"
最后調(diào)用wandb.finish()
,整體變化不大
# This script needs these libraries to be installed:
# numpy, transformers, datasets
import wandb
import os
import numpy as np
from datasets import load_dataset
from transformers import TrainingArguments, Trainer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# 設(shè)置GPU編號(hào)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
def tokenize_function(examples):
return tokenizer(examples["text"], padding="max_length", truncation=True)
def compute_metrics(eval_pred):
logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
return {"accuracy": np.mean(predictions == labels)}
print("Loading Dataset")
# download prepare the data
dataset = load_dataset("yelp_review_full")
print("Loading Tokenizer")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
small_train_dataset = dataset["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = dataset["test"].shuffle(seed=42).select(range(300))
small_train_dataset = small_train_dataset.map(tokenize_function, batched=True)
small_eval_dataset = small_train_dataset.map(tokenize_function, batched=True)
print("Loading Model")
# download the model
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=5)
# set the wandb project where this run will be logged
os.environ["WANDB_PROJECT"]="my-awesome-project"
# save your trained model checkpoint to wandb
os.environ["WANDB_LOG_MODEL"]="true"
# turn off watch to log faster
os.environ["WANDB_WATCH"]="false"
# pass "wandb" to the 'report_to' parameter to turn on wandb logging
training_args = TrainingArguments(
output_dir='models',
report_to="wandb",
logging_steps=5,
per_device_train_batch_size=32,
per_device_eval_batch_size=32,
evaluation_strategy="steps",
eval_steps=20,
max_steps = 100,
save_steps = 100
)
print("Loading Trainer")
# define the trainer and start training
trainer = Trainer(
model=model,
args=training_args,
train_dataset=small_train_dataset,
eval_dataset=small_eval_dataset,
compute_metrics=compute_metrics,
)
print("Training")
trainer.train()
# [optional] finish the wandb run, necessary in notebooks
wandb.finish()
- 在 wandb 網(wǎng)站中
我們可以打開(kāi)該 project。每一次運(yùn)行相當(dāng)于一次run
,我這里跑了三次所以就有三條線。
這里主要是看eval
驗(yàn)證集和train
訓(xùn)練集的一些參數(shù)。 - 我們可以刪掉不關(guān)心的面板,或者增添一個(gè)想看的面板
但如果兩個(gè)參數(shù)的值域變化比較大的話,在一個(gè)圖里面比較難看清,所以比較相關(guān)的參數(shù)才建議放在一個(gè)圖里。
低級(jí) API
- 這上面是封裝比較高級(jí)的 API,一般我們也都配合
transformers
庫(kù)去用
如果想用比較原生的 API,一般用法如下:
首先調(diào)用wandb.init()
方法
然后使用wandb.log(dict)
輸出你要可視化的參數(shù)即可。
# train.py
import wandb
import random # for demo script
wandb.login()
epochs = 10
lr = 0.01
run = wandb.init(
# Set the project where this run will be logged
project="my-awesome-project",
# Track hyperparameters and run metadata
config={
"learning_rate": lr,
"epochs": epochs,
},
)
offset = random.random() / 5
print(f"lr: {lr}")
# simulating a training run
for epoch in range(2, epochs):
acc = 1 - 2**-epoch - random.random() / epoch - offset
loss = 2**-epoch + random.random() / epoch + offset
print(f"epoch={epoch}, accuracy={acc}, loss={loss}")
wandb.log({"accuracy": acc, "loss": loss})
# run.log_code()
文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-845933.html
文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-845933.html
到了這里,關(guān)于【Python】科研代碼學(xué)習(xí):十四 wandb (可視化AI工具)的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!