LLM Basics

  • Basic setup
!pip install -q bitsandbytes datasets accelerate loralib
!pip install -q git+https://github.com/huggingface/transformers.git@main  # install from source
!pip install -q git+https://github.com/huggingface/peft.git

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.from_pretrained("bigscience/bloom-7b1")
AutoConfig  # inspect the model configuration

# `model` is assumed to be loaded beforehand, e.g.:
# model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-7b1", load_in_8bit=True, device_map="auto")
list(model.parameters())[0].dtype  # check the dtype of the model parameters

for i, param in enumerate(model.parameters()):
    param.requires_grad = False  # freeze the model - train adapters later
    if param.ndim == 1:
        # cast the small parameters (e.g. layernorm) to fp32 for stability
        param.data = param.data.to(torch.float32)  # higher precision, more stable training

class CastOutputToFloat(nn.Sequential):
    def forward(self, x):
        return super().forward(x).to(torch.float32)
model.lm_head = CastOutputToFloat(model.lm_head)  # higher precision, more stable training

model.gradient_checkpointing_enable()  # reduce memory usage
model.enable_input_require_grads()  # compute gradients w.r.t. the model inputs

from peft import LoraConfig, get_peft_model
config = LoraConfig(
    r=16,  # low rank
    lora_alpha=32,
    # target_modules=["q_proj", "v_proj"],  # if you know
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"  # set this for CLM or Seq2Seq
)
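
A minimal sketch of applying this config to the frozen model above: get_peft_model wraps the base model with LoRA adapters, and print_trainable_parameters shows that only a small fraction of the weights will actually be trained.

model = get_peft_model(model, config)
model.print_trainable_parameters()  # trainable params are a small fraction of the total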
  • Data processing
import torch
from datasets import load_dataset
# `tokenizer` and `model` are the ones set up in the previous section
dataset = load_dataset("Abirate/english_quotes")

dataset['train'].to_pandas()  # convert to a pandas DataFrame
dataset['train']['author'][:4]

def merge(row):
    row['prediction'] = row['quote'] + ' ->: ' + str(row['tags'])
    return row
dataset['train'] = dataset['train'].map(merge)  # build the new 'prediction' field

tokenizer(dataset['train']['prediction'][:4])  # returns input_ids and attention_mask

dataset = dataset.map(lambda samples: tokenizer(samples['prediction']), batched=True)  # derive new 'input_ids' and 'attention_mask' fields from 'prediction'

# nvitop

batch = tokenizer("“Training models with PEFT and LoRa is cool” ->: ", return_tensors='pt')
with torch.cuda.amp.autocast():
    output_tokens = model.generate(**batch, max_new_tokens=50)
print('\n\n', tokenizer.decode(output_tokens[0], skip_special_tokens=True))
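
To actually fine-tune on the tokenized dataset, one common pattern (a sketch; the training hyperparameters here are illustrative, and model/tokenizer are the PEFT-wrapped BLOOM model and its tokenizer from above) is a Trainer with a causal-LM data collator:

import transformers

trainer = transformers.Trainer(
    model=model,
    train_dataset=dataset['train'],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        warmup_steps=100,
        max_steps=200,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir='outputs',
    ),
    # mlm=False => causal LM objective: labels are the (shifted) input_ids
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False  # silence the checkpointing warning; re-enable for inference
trainer.train()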
  • Mixed precision; different layers can be placed on different devices
# import os
# os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
# os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'

# over https
!pip install -q git+https://github.com/huggingface/transformers.git
# over ssh
!pip install -q git+ssh://git@github.com/huggingface/transformers.git

from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf",
                                          load_in_8bit=True,   # 8-bit loading (mixed precision)
                                          device_map="auto",   # different layers can be placed on different devices
                                          )
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")

for i, para in enumerate(model.named_parameters()):
    print(f'{i}, \t {para[1].device} \t{para[1].dtype}')  # check each parameter's dtype and device

from peft import PeftModel
model = PeftModel.from_pretrained(model, "tloen/alpaca-lora-7b")

from peft import mapping
from peft.utils import other
print('model_type', model.config.model_type)
print(model.peft_config['default'].target_modules)
other.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING  # look up the default target modules per model type
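
With device_map="auto", accelerate records where each module ended up; two quick checks (a sketch; for the PEFT-wrapped model these attributes are forwarded from the underlying base model):

print(model.hf_device_map)                     # which device each module was dispatched to
print(model.get_memory_footprint() / 1024**3)  # model size in GB (transformers utility)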
  • An Alpaca inference example
# continues from the previous section (model, tokenizer, GenerationConfig)
def generate_prompt(instruction, input=None):
    if input:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:"""

generation_config = GenerationConfig(
    temperature=1.5,
    # nucleus sampling
    top_p=0.8,
    num_beams=4,
)

def inference(instruction, input=None):
    prompt = generate_prompt(instruction, input)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].cuda()
    generation_output = model.generate(  # model.generate returns token ids,
        input_ids=input_ids,             # which still need tokenizer.decode to become text
        generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=256
    )
    for s in generation_output.sequences:
        output = tokenizer.decode(s)
        print("Response:", output.split("### Response:")[1].strip())

inference(input("Instruction: "))
  • Using torch.cuda.amp: loss scaling to push up the maximum batch size
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# Simple CNN
class CNN(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=5120,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # /2, downsampling
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=5120,
            out_channels=10240,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # (channels*w*h)
        # w, h depend on the initial width/height (7*7 corresponds to a 28x28 input after two /2 poolings)
        self.fc1 = nn.Linear(10240 * 7 * 7, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        # /2
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        # /2
        x = self.pool(x)
        # 4d => 2d, (bs, features)
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        return x

from torchsummary import summary
model = CNN(in_channels=3)
summary(model, input_size=(3, 28, 28), batch_size=32, device='cpu')  # show how the shape changes after each layer (28x28 input matches fc1)

def train():
    for epoch in tqdm(range(num_epochs)):
        for batch_idx, (batch_x, batch_y) in tqdm(enumerate(train_loader)):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            # forward
            # logits = model(batch_x)
            # loss = criterion(logits, batch_y)
            with torch.cuda.amp.autocast():
                logits = model(batch_x)
                loss = criterion(logits, batch_y)

            # backward
            optimizer.zero_grad()
            # loss.backward()
            scaler.scale(loss).backward()

            # gradient descent
            # optimizer.step()
            scaler.step(optimizer)
            scaler.update()
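
The train() function above assumes the usual plumbing (device, data, loss, optimizer and the AMP GradScaler). A minimal setup sketch; num_epochs, the FakeData dataset and the hyperparameters are illustrative, not from the original notes:

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_epochs = 1

# illustrative data shaped like the 28x28 input assumed by fc1 above
train_loader = DataLoader(
    datasets.FakeData(size=512, image_size=(3, 28, 28), num_classes=10,
                      transform=transforms.ToTensor()),
    batch_size=32, shuffle=True)

model = CNN(in_channels=3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()  # scales the loss so fp16 gradients do not underflow

train()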
  • A model.generate example
# pip install bitsandbytes
# pip install transformers
# pip install accelerate

from transformers import AutoTokenizer, AutoModelForCausalLM

MAX_NEW_TOKENS = 128
ckpt = 'facebook/opt-6.7b'
sample = 'hello, who are you?'
tokenizer = AutoTokenizer.from_pretrained(ckpt)
input_ids = tokenizer(sample, return_tensors='pt').input_ids

model = AutoModelForCausalLM.from_pretrained(ckpt, device_map='auto', load_in_8bit=True)
generated_ids = model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS)
tokenizer.decode(generated_ids[0], skip_special_tokens=True)
  • Basic tokenizer interfaces; how to train a tokenizer
from transformers import AutoTokenizer
tokenizer_t5 = AutoTokenizer.from_pretrained('t5-base')
def tokenize_str(tokenizer, text):  # return the decoded tokens, to see whether common words get split further
    input_ids = tokenizer(text, add_special_tokens=False)['input_ids']
    return [tokenizer.decode(token_id) for token_id in input_ids]

python_code = r'''def say_hello():
    print('Hello, World!')

# print hello
say_hello()
'''
tokenizer = AutoTokenizer.from_pretrained('gpt2')
print(tokenizer(python_code)['input_ids'])
print(tokenizer(python_code).tokens())  # inspect the resulting tokens

tokenizer.backend_tokenizer.normalizer
tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(python_code)

# Unicode character composed of 1-4 bytes
a, e = u'a', u'€'
# 1 byte
byte = ord(a.encode('utf-8'))
print(f"{a}, {a.encode('utf-8')}, {byte}")
# 3 bytes
# byte = ord(e.encode('utf-8'))
# ord expects a single character
# ord: character -> integer; chr: integer -> character
byte = [ord(chr(i)) for i in e.encode('utf-8')]
print(f"{e}, {e.encode('utf-8')}, {byte}")

# training a tokenizer
#   no weights and no backpropagation are involved
# the tokenizer processing pipeline:
#   normalization
#   pretokenization
#   tokenizer model
#   postprocessing
# subword tokenization algorithms (subword: tokens are parts of words)
#   BPE: byte pair encoding
#     iteratively ADD merges until a target vocabulary size is reached
#   WordPiece
#   Unigram
#     iteratively REMOVE tokens until a target vocabulary size is reached
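
To make the "iteratively add merges" idea concrete, here is a toy BPE merge step in plain Python (an illustration only, not how the tokenizers library implements it): count the most frequent adjacent symbol pair and merge it into a new symbol.

from collections import Counter

# toy corpus, each word already split into characters; '</w>' marks the end of a word
corpus = [list("low") + ["</w>"], list("lower") + ["</w>"], list("lowest") + ["</w>"]]

def most_frequent_pair(words):
    pairs = Counter()
    for w in words:
        for a, b in zip(w, w[1:]):
            pairs[(a, b)] += 1
    return pairs.most_common(1)[0][0]

def merge_pair(words, pair):
    merged = []
    for w in words:
        out, i = [], 0
        while i < len(w):
            if i + 1 < len(w) and (w[i], w[i + 1]) == pair:
                out.append(w[i] + w[i + 1])  # the merged symbol becomes a new vocabulary entry
                i += 2
            else:
                out.append(w[i])
                i += 1
        merged.append(out)
    return merged

pair = most_frequent_pair(corpus)  # ('l', 'o'), the most frequent adjacent pair
corpus = merge_pair(corpus, pair)
print(pair, corpus[0])             # ('l', 'o') ['lo', 'w', '</w>']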
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
bytes_to_unicode_map = bytes_to_unicode()
bytes_to_unicode_map

unicode_to_bytes_map = dict((v, k) for k, v in bytes_to_unicode_map.items())
unicode_to_bytes_map

base_vocab = list(unicode_to_bytes_map.keys())
print(base_vocab[0], base_vocab[-1])

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(python_code)
tokens = sorted(tokenizer.vocab.items(), key=lambda x: len(x[0]), reverse=True)
[tokenizer.convert_tokens_to_string([token]) for token, _ in tokens[:10]]
tokens = sorted(tokenizer.vocab.items(), key=lambda x: x[1], reverse=True)
[tokenizer.convert_tokens_to_string([token]) for token, _ in tokens[:12]]
Description                                         | Character | Bytes     | Mapped bytes
Regular characters                                  | a and ?   | 97 and 63 | a and ?
A nonprintable control character (carriage return)  | U+000D    | 13        | č
A space                                             | U+0020    | 32        | Ġ
A nonbreakable space                                | U+00A0    | 160       | ł
A newline character                                 | U+000A    | 10        | Ċ
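
The table can be reproduced directly from the mapping above (a small sketch; the gpt2 tokenizer is the one already loaded): a leading space shows up as Ġ and a newline as Ċ.

tokenizer = AutoTokenizer.from_pretrained('gpt2')
print(tokenizer.tokenize('hello world'))  # ['hello', 'Ġworld'] -- the space is carried as Ġ
print(tokenizer.tokenize('a\nb'))         # the newline appears as Ċ
print(bytes_to_unicode_map[ord(' ')], bytes_to_unicode_map[ord('\n')])  # Ġ Ċ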
# training a tokenizer
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer

dataset = load_dataset('./codeparrot/', split='train', streaming=True)
iter_dataset = iter(dataset)
tokenizer = AutoTokenizer.from_pretrained('gpt2')

from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
bytes_to_unicode_map = bytes_to_unicode()
unicode_to_bytes_map = dict((v, k) for k, v in bytes_to_unicode_map.items())
base_vocab = list(unicode_to_bytes_map.keys())

length = 100000
def batch_iterator(batch_size=1000):
    for _ in tqdm(range(0, length, batch_size)):
        yield [next(iter_dataset)['content'] for _ in range(batch_size)]

new_tokenizer = tokenizer.train_new_from_iterator(batch_iterator(),
                                                  vocab_size=12500,
                                                  initial_alphabet=base_vocab)

tokens = sorted(new_tokenizer.vocab.items(), key=lambda x: x[1], reverse=False)
[(t, new_tokenizer.convert_tokens_to_string([t])) for t, _ in tokens[257:280]]  # inspect the learned tokens

import keyword
len(keyword.kwlist)
for kw in keyword.kwlist:  # check which Python keywords are still missing from the vocab
    if kw not in new_tokenizer.vocab:
        print(f'`{kw}` not in the new tokenizer')

# push to the Hugging Face Hub
import os
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'
ckpt = 'asdfgh'
org = 'asdfghjkl'
new_tokenizer.push_to_hub(ckpt, organization=org)
  • Dataset and IterableDataset
import numpy as np
from torch.utils.data import DataLoader, Dataset, IterableDataset

class MyDataset(Dataset):
    def __init__(self, m, n):
        self.x = np.random.randn(m, n)
        self.y = list(range(m))
    def __getitem__(self, i):
        return self.x[i], self.y[i]
    def __len__(self):
        return len(self.y)

ds = MyDataset(100, 5)
print(len(ds))  # 100
ds[0]  # (array([ 0.06201527, -0.77968078, -0.68125061, -1.77969614, -0.66575581]), 0)

class MyIterableDataset(IterableDataset):
    def __init__(self, x, y):
        super().__init__()
        self.start = x
        self.end = y
    def __iter__(self):
        return iter(range(self.start, self.end))

ds1 = MyIterableDataset(3, 8)
ds2 = MyIterableDataset(9, 15)
ds3 = ds1 + ds2  # '+' chains the two iterable datasets
[i for i in ds1]  # [3, 4, 5, 6, 7]
[i for i in ds3]  # [3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14]
loader = DataLoader(ds3)
list(loader)
# [tensor([3]),
#  tensor([4]),
#  tensor([5]),
#  tensor([6]),
#  tensor([7]),
#  tensor([9]),
#  tensor([10]),
#  tensor([11]),
#  tensor([12]),
#  tensor([13]),
#  tensor([14])]

# infinite dataset
rng = np.random.default_rng()
class InfIterableDataset(IterableDataset):
    def __init__(self, n):
        super().__init__()
        self._n = n
    def __iter__(self):
        start = 0
        while True:
            x = np.arange(start, start + self._n)
            y = rng.choice([0, 1], size=1, p=[0.4, 0.6])
            yield x, y
            start += self._n

window = 5
cnt = 0
for (x, y) in InfIterableDataset(window):
    print(x, y)
    if cnt >= 5:
        break
    cnt += 1
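
One gotcha with IterableDataset: with num_workers > 0 every worker replays the same stream, so each item comes back once per worker unless the dataset shards itself. A sketch of the usual fix via torch.utils.data.get_worker_info():

from torch.utils.data import IterableDataset, DataLoader, get_worker_info

class ShardedRange(IterableDataset):
    def __init__(self, start, end):
        super().__init__()
        self.start, self.end = start, end
    def __iter__(self):
        info = get_worker_info()
        if info is None:  # single-process loading: yield everything
            lo, hi, step = self.start, self.end, 1
        else:             # each worker takes an interleaved slice of the range
            lo, hi, step = self.start + info.id, self.end, info.num_workers
        return iter(range(lo, hi, step))

loader = DataLoader(ShardedRange(0, 10), num_workers=2)
print(sorted(int(x) for x in loader))  # [0, 1, ..., 9] with no duplicates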
  • mapping and streaming
from datasets import load_dataset
data_files = "/media/whaow/datasets/PUBMED_title_abstracts_2019_baseline.jsonl"
# memory mapped
# large_dataset = load_dataset("json", data_files=data_files, split="train")
# tokenizer_dataset2 = large_dataset.map(lambda x: tokenizer(x['text']), batched=True, batch_size=20000)
# streaming
large_dataset_streamed = load_dataset("json", data_files=data_files, split="train", streaming=True)
tokenizer_dataset = large_dataset_streamed.map(lambda x: tokenizer(x['text']))

import psutil
# memory usage of the current process
print(f'{psutil.Process().memory_info().rss/(1024**2):.2f} MB')

import timeit
code_snippet = '''batch_size = 20000
for idx in tqdm(range(0, len(large_dataset), batch_size)):
    _ = large_dataset[idx: idx+batch_size]
'''  # run this snippet twice; timeit returns the total time for both runs
duration = timeit.timeit(stmt=code_snippet, number=2, globals=globals())

next(iter(tokenizer_dataset)).keys()
list(large_dataset_streamed.take(5))[-1]
large_dataset[4]  # same result as the previous line
next(iter(large_dataset_streamed.skip(1000)))
large_dataset[1000]  # same result as the previous line
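
Streamed splits also support an approximate, buffer-based shuffle (a sketch; buffer_size and seed are illustrative values):

shuffled = large_dataset_streamed.shuffle(buffer_size=10_000, seed=42)  # fill a buffer, then sample from it
first = next(iter(shuffled))
print(first.keys())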
  • retain graph and GPU memory occupied
import torch
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo

x = torch.tensor(1., requires_grad=True)
y = x**2
y.backward(retain_graph=True)
print(x.grad)  # tensor(2.)
y.backward()
print(x.grad)  # tensor(4.), gradients accumulate across backward calls

def print_gpu_utilization():  # GPU memory occupied: 21296 MB.
    nvmlInit()
    total_used = 0
    for i in range(torch.cuda.device_count()):
        handle = nvmlDeviceGetHandleByIndex(i)
        info = nvmlDeviceGetMemoryInfo(handle)
        total_used += info.used
    print(f"GPU memory occupied: {total_used//1024**2} MB.")
  • SFT
from datasets import load_dataset
train_dataset = load_dataset("tatsu-lab/alpaca", split="train")
train_dataset
# Dataset({
#     features: ['instruction', 'input', 'output', 'text'],
#     num_rows: 52002
# })
print(train_dataset[0])
print(train_dataset[0]['text'])

# `model` and `tokenizer` are assumed to be loaded beforehand
# check that the tokenizer's vocab_size matches the model's embedding layer
print(tokenizer)
print(model.model.embed_tokens)
model.resize_token_embeddings(len(tokenizer))

from transformers import TrainingArguments
from trl import SFTTrainer
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training  # renamed prepare_model_for_kbit_training in newer peft
model = prepare_model_for_int8_training(model)
peft_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
model = get_peft_model(model, peft_config)
training_args = TrainingArguments(
    output_dir="xgen-7b-tuned-alpaca-l1",
    per_device_train_batch_size=4,
    optim="adamw_torch",
    logging_steps=10,
    learning_rate=2e-4,
    fp16=True,
    warmup_ratio=0.1,
    lr_scheduler_type="linear",
    num_train_epochs=1,
    save_strategy="epoch",
    push_to_hub=False,
)
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=1024,
    tokenizer=tokenizer,
    args=training_args,
    packing=True,
    peft_config=peft_config,
)
trainer.train()
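
After training, the adapter can be saved on its own, or merged back into the base weights so inference needs no PEFT wrapper; a sketch (output paths are illustrative, and merging may require reloading the base model in fp16 rather than int8):

trainer.save_model("xgen-7b-tuned-alpaca-l1")            # saves only the small LoRA adapter
merged = trainer.model.merge_and_unload()                # fold the adapters into the base weights
merged.save_pretrained("xgen-7b-tuned-alpaca-l1-merged")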
  • gradient checkpointing
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # multi-GPU machine -> use a single GPU

# `model`, `ds` and print_gpu_utilization() are assumed from the earlier sections
def print_summary(result):
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()

from transformers import TrainingArguments, Trainer
default_args = {
    "output_dir": "tmp",
    "evaluation_strategy": "steps",
    "num_train_epochs": 1,
    "log_level": "error",
    "report_to": "none",
}
# training_args = TrainingArguments(per_device_train_batch_size=4, **default_args)
training_args = TrainingArguments(
    per_device_train_batch_size=1, gradient_accumulation_steps=4, gradient_checkpointing=True, **default_args
)
trainer = Trainer(model=model, args=training_args, train_dataset=ds)
result = trainer.train()
print_summary(result)
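
What gradient_checkpointing=True does under the hood: intermediate activations are dropped in the forward pass and recomputed during backward, trading compute for memory. A minimal sketch with torch.utils.checkpoint (layer sizes are illustrative; use_reentrant=False needs a reasonably recent PyTorch):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
x = torch.randn(8, 1024, requires_grad=True)

y = checkpoint(block, x, use_reentrant=False)  # activations inside `block` are recomputed in backward
y.sum().backward()
print(x.grad.shape)  # torch.Size([8, 1024])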
  • pipeline
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
pipeline = transformers.pipeline(
    "text-generation",
    model=model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)
print(tokenizer)
print(pipeline.model)

sequences = pipeline(
    'I liked "Breaking Bad" and "Band of Brothers". Do you have any recommendations of other shows I might like?\n',
    do_sample=True,
    top_k=10,
    num_return_sequences=3,
    eos_token_id=tokenizer.eos_token_id,
    max_length=200,
)
for seq in sequences:
    print(seq['generated_text'] + '\n\n')
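
Llama-2 chat checkpoints were tuned on a specific prompt template; wrapping the question in it (a sketch of the [INST] format, with an illustrative system prompt) usually gives better answers than raw text:

prompt = (
    "<s>[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\n"
    'I liked "Breaking Bad" and "Band of Brothers". '
    "Do you have any recommendations of other shows I might like? [/INST]"
)
out = pipeline(prompt, do_sample=True, top_p=0.9, max_new_tokens=200)
print(out[0]['generated_text'])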
  • trl
import torch
import numpy as np
import random
from transformers import GPT2Tokenizer
from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

def setup_seed(seed):  # make the results reproducible
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
setup_seed(1)

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
print(tokenizer)
print(model.pretrained_model)
print(model.v_head)
# ValueHead(
#   (dropout): Dropout(p=0.1, inplace=False)
#   (summary): Linear(in_features=768, out_features=1, bias=True)
#   (flatten): Flatten(start_dim=1, end_dim=-1)
# )

ppo_config = {"batch_size": 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)

query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device)
query_tensor  # tensor([[1212, 3329,  314, 1816,  284,  262,  220]], device='cuda:0')

generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 20,
}
response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=True, **generation_kwargs)
response_txt = tokenizer.decode(response_tensor[0])
response_txt  # 'This morning I went to the vernacular and found myself at a bar, cook, with a wife. Buggas together in'
reward = [torch.tensor(1.0, device=model.pretrained_model.device)]  # a simplified, hand-written reward
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

input_ids = torch.cat([query_tensor[0], response_tensor[0]])
base_model_output = model.pretrained_model(input_ids, output_hidden_states=True)
last_hidden_state = base_model_output.hidden_states[-1]
print(last_hidden_state.shape)  # torch.Size([34, 768])
lm_logits = base_model_output.logits
print(lm_logits.shape)  # torch.Size([34, 50257])
with torch.no_grad():
    # (34, 768) @ (768, 1) => (34, 1)
    value = model.v_head(last_hidden_state).squeeze(-1)
    print(value.shape)  # torch.Size([34])
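
In a real PPO loop the hand-written reward above would come from a reward model. A common sketch is to score the response with a sentiment classifier and use that score as the reward (the lvwerra/distilbert-imdb checkpoint is the one used in the TRL examples; treat the exact label strings as an assumption):

from transformers import pipeline

sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb")
out = sentiment_pipe(response_txt)[0]  # e.g. {'label': 'POSITIVE', 'score': 0.9}
score = out["score"] if out["label"] == "POSITIVE" else -out["score"]
reward = [torch.tensor(score, device=model.pretrained_model.device)]
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)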
  • RMSNorm; Swish/SiLU
import torch
import torch.nn as nn
import numpy as np

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

bs, seq_len, embedding_dim = 2, 16, 512  # illustrative shapes
x = torch.randn(bs, seq_len, embedding_dim)
rms_norm = RMSNorm(embedding_dim)
x_rms = rms_norm(x)

def sigmoid(x):
    return 1/(1 + np.exp(-x))

def swish(x):
    return x*sigmoid(x)

import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 120
x = np.arange(-5, 5, .01)
plt.plot(x, swish(x))

x = torch.randn(5)
x/(1+torch.exp(-x))
import torch.nn.functional as F
F.silu(x)
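
A quick sanity check of what RMSNorm does and does not do (a sketch reusing rms_norm above): the RMS over the last dimension becomes ~1, but the mean is not subtracted, unlike LayerNorm.

x = torch.randn(bs, seq_len, embedding_dim)
with torch.no_grad():
    out = rms_norm(x)
print(out.pow(2).mean(-1).sqrt().mean())  # ~1.0: the root-mean-square is normalized away
print(out.mean(-1).abs().max())           # generally far from 0: the mean is NOT removed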
  • KV cache
import time
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

for use_cache in (True, False):
    times = []
    for _ in range(10):  # measuring 10 generations
        start = time.time()
        model.generate(**tokenizer("What is KV caching?", return_tensors="pt").to(device),
                       use_cache=use_cache,
                       max_new_tokens=1000)
        times.append(time.time() - start)
    mu = round(np.mean(times), 3)
    std = round(np.std(times), 3)
    print(f"{'with' if use_cache else 'without'} KV caching: {mu} +- {std} seconds")