I have a long chunk of text that I need to process with a transformer. Users will then ask different questions about it (all questions are independent; they don't relate to each other).
The input always looks like this:
...
< LONG TEXT, always the same >
...
< USER QUESTION, dynamic >
I don't want to recompute attention over the long text every time.
My question is: how can I store the Keys and Values (and, as a bonus, also the embeddings) for the long text and then reuse them over time? (I sketch what I'm hoping for below, after my requirements.)
- I am open to any framework
- I want to use python
- I want to use the Llama 3.2 3B model
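Conceptually, this is the pattern I am hoping for. It is a minimal sketch based on the prompt-reuse example in the transformers cache docs, as far as I understand it; DynamicCache, the deepcopy step, and the long_text / user_question variables are my assumptions/placeholders, not tested code:

import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B").to("cuda")

# pay for the long text once: a single forward pass fills the cache
# (long_text is a placeholder for the large document)
long_inputs = tokenizer(long_text, return_tensors="pt").to("cuda")
with torch.no_grad():
    prompt_cache = model(**long_inputs, past_key_values=DynamicCache()).past_key_values

# per question: deep-copy the cache (generate() mutates it in place),
# pass the full prompt, and let generate() skip the already-cached prefix
# (user_question is a placeholder for the dynamic part)
question_inputs = tokenizer(long_text + user_question, return_tensors="pt").to("cuda")
answer = model.generate(
    **question_inputs,
    past_key_values=copy.deepcopy(prompt_cache),
    max_new_tokens=64,
)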
My code:
So far, I have been experimenting with the first step: processing the input text, pausing, and then continuing. I did not use Hugging Face's model(...) call, as it was taking up considerably more VRAM than model.generate(...).
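For reference, the plain forward-pass variant I tried first looked roughly like this (it uses the same model, tokenized and device as the snippet below; my guess is that the extra VRAM comes from the forward pass materializing logits for every input position, while generate() only needs the last one, but I have not verified that):

with torch.no_grad():
    out = model(input_ids=tokenized["input_ids"].to(device), use_cache=True)
past_kv = out.past_key_values  # one (key, value) pair per layer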
import torch
import time
from transformers import AutoTokenizer, AutoModelForCausalLM

# load model and tokenizer:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

tokenized = tokenizer(long_text, return_tensors="pt")  # long_text holds the large document (defined elsewhere)

with torch.no_grad():
    n = 5000
    t1 = time.time()
    # step 1) calculate KV for the first 5000 tokens
    outputs = model.generate(
        input_ids=tokenized['input_ids'].to(device)[:, :n],
        attention_mask=torch.ones((1, n), dtype=torch.int).to(device),
        temperature=0.7,
        max_new_tokens=1,
        use_cache=True,
        return_dict_in_generate=True
    )
    t2 = time.time()
    past_kv = outputs.past_key_values
    # step 2) try to reuse the cache from the previous call
    outputs = model.generate(
        input_ids=outputs.sequences,
        attention_mask=torch.ones((1, outputs.sequences.shape[1]), dtype=torch.int).to(device),
        past_key_values=past_kv,  # this is not being used, why?
        temperature=0.7,
        max_new_tokens=1,
        use_cache=True,
        return_dict_in_generate=True
    )
    t3 = time.time()

print(t2 - t1)  # ~3 seconds
print(t3 - t2)  # also ~3 seconds, should be faster
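To check whether the cache is even being consumed in step 2, I have been inspecting it like this (the shapes in the comments are illustrative; on recent transformers versions past_kv is a DynamicCache object, on older ones a tuple of per-layer (key, value) pairs, but indexing works the same in both):

print(len(past_kv))      # number of decoder layers
key, value = past_kv[0]  # key/value tensors of the first layer
print(key.shape)         # (batch, num_kv_heads, seq_len, head_dim); seq_len should match the cached prompt (~5000 here)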
I have seen tips to pass only the newly generated (last) token while keeping the previous KVs, like this:
# step 2), alternative: feed only the new token
new_token_id = outputs.sequences[:, -1:].to(device)  # only the last generated token
past_kv = outputs.past_key_values
outputs2 = model.generate(  # throws the exception below
    input_ids=new_token_id,
    past_key_values=past_kv,
    max_new_tokens=1,
    use_cache=True,
    return_dict_in_generate=True
)
However, this approach throws an exception:
IndexError: index -1 is out of bounds for dimension 0 with size 0
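One more thing I have not fully tested: newer transformers versions apparently expect a Cache object rather than the legacy tuple format, and generate() seems to want the full sequence (it crops off the cached prefix internally) rather than just the new token. Something along these lines; the DynamicCache wrapping is my assumption based on the deprecation warnings, not verified:

from transformers import DynamicCache

# wrap the legacy tuple-of-tuples in a Cache object if necessary
if isinstance(past_kv, tuple):
    past_kv = DynamicCache.from_legacy_cache(past_kv)

outputs2 = model.generate(
    input_ids=outputs.sequences,  # full sequence, not just the last token
    attention_mask=torch.ones_like(outputs.sequences),
    past_key_values=past_kv,
    max_new_tokens=1,
    use_cache=True,
    return_dict_in_generate=True,
)

Is wrapping the cache like this the right direction, or is there a better way to persist and reuse it?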