Understanding LLMs / Transformers (save a copy of the notebook before running the code)
Check the status of your GPU
!nvidia-smi
Tue Apr 22 13:09:02 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |
| N/A 41C P8 11W / 70W | 0MiB / 15360MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+
Install transformers for later use (please do not change the pinned version, so that the model runs reliably)
!pip install transformers==4.47.0
Collecting transformers==4.47.0
Downloading transformers-4.47.0-py3-none-any.whl.metadata (43 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 43.5/43.5 kB 1.6 MB/s eta 0:00:00
Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from transformers==4.47.0) (3.18.0)
Requirement already satisfied: huggingface-hub<1.0,>=0.24.0 in /usr/local/lib/python3.11/dist-packages (from transformers==4.47.0) (0.30.2)
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from transformers==4.47.0) (2.0.2)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from transformers==4.47.0) (24.2)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from transformers==4.47.0) (6.0.2)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers==4.47.0) (2024.11.6)
Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from transformers==4.47.0) (2.32.3)
Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.11/dist-packages (from transformers==4.47.0) (0.21.1)
Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.11/dist-packages (from transformers==4.47.0) (0.5.3)
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.11/dist-packages (from transformers==4.47.0) (4.67.1)
Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.24.0->transformers==4.47.0) (2025.3.2)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.24.0->transformers==4.47.0) (4.13.2)
Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->transformers==4.47.0) (3.4.1)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->transformers==4.47.0) (3.10)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->transformers==4.47.0) (2.3.0)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->transformers==4.47.0) (2025.1.31)
Downloading transformers-4.47.0-py3-none-any.whl (10.1 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 10.1/10.1 MB 46.3 MB/s eta 0:00:00
Installing collected packages: transformers
Attempting uninstall: transformers
Found existing installation: transformers 4.51.3
Uninstalling transformers-4.51.3:
Successfully uninstalled transformers-4.51.3
Successfully installed transformers-4.47.0
Hugging Face login
You need a Hugging Face access token (hf_token) to log in to Hugging Face and download the Gemma model, so make sure you have created one (described in the Google Slides).
######################## TODO (Pre-requisites) ########################
# replace `your_hf_token` with your Hugging Face token
# (do not share a copy of the notebook that still contains a real token)
from huggingface_hub import login
login("your_hf_token")
#######################################################################
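Rather than pasting a token into the cell, it can be read at run time. A minimal sketch, assuming the token is stored either in an environment variable or in a Colab secret named `HF_TOKEN` (the secret name the warning printed when the model is loaded looks for); `get_hf_token` is a hypothetical helper, not part of `huggingface_hub`:

```python
import os
from typing import Optional


def get_hf_token(env_var: str = "HF_TOKEN") -> Optional[str]:
    """Return a Hugging Face token from the environment or Colab secrets, else None."""
    # 1) Environment variable (works locally and on most hosted runtimes)
    token = os.environ.get(env_var)
    if token:
        return token
    # 2) Colab secret; this import only succeeds inside Google Colab
    try:
        from google.colab import userdata
        return userdata.get(env_var)
    except Exception:
        # Not running in Colab, or the secret is missing / not shared with this notebook
        return None


# Hypothetical usage in place of a hard-coded token:
# from huggingface_hub import login
# token = get_hf_token()
# if token:
#     login(token)
```

Returning `None` instead of raising keeps public models usable when no token is configured.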
Download the model
Gemma Model: https://huggingface.co/google/gemma-2-2b-it
Please accept the license to download the Gemma model (as described in the Google Slides).
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model_id = "unsloth/gemma-2-2b-it"
dtype = torch.float16
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda",
    torch_dtype=dtype,
)
/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning:
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
warnings.warn(
Q1: Chat template Comparison
Evaluation Model: https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
SCORING_MODEL = AutoModelForSequenceClassification.from_pretrained('cross-encoder/ms-marco-MiniLM-L-6-v2')
SCORING_TOKENIZER = AutoTokenizer.from_pretrained('cross-encoder/ms-marco-MiniLM-L-6-v2')
def calculate_coherence(question, answer, scoring_model=SCORING_MODEL, tokenizer=SCORING_TOKENIZER):
    # Score the (question, answer) pair with the cross-encoder; returns its raw logit (higher = more relevant)
    features = tokenizer([question], [answer], padding=True, truncation=True, return_tensors="pt")
    scoring_model.eval()
    with torch.no_grad():
        scores = scoring_model(**features).logits.squeeze().item()
    return scores
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Observe whether the chat template affects the model’s output results.
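Before comparing outputs, it helps to see what `apply_chat_template` actually adds. Judging from the prompts printed in Q1 and Q2, Gemma wraps every turn in `<start_of_turn>{role}` ... `<end_of_turn>` markers, names the assistant role `model`, and, with `add_generation_prompt=True`, opens an empty model turn. A hypothetical re-implementation of just that layout (the real Jinja template may also trim whitespace and reject unsupported roles such as `system`, which this sketch does not do):

```python
def format_gemma_chat(messages, add_generation_prompt=True):
    """Hypothetical sketch of the Gemma prompt layout printed in this notebook."""
    parts = ["<bos>"]
    for message in messages:
        # Gemma calls the assistant role "model"
        role = "model" if message["role"] == "assistant" else message["role"]
        parts.append(f"<start_of_turn>{role}\n{message['content']}<end_of_turn>\n")
    if add_generation_prompt:
        # Open a model turn so the next generated tokens form the reply
        parts.append("<start_of_turn>model\n")
    return "".join(parts)


print(format_gemma_chat([{"role": "user", "content": "hello"}]))
```

Without the template (the plain-text run in Q1), none of these markers are present, so the model simply continues raw text instead of answering a user turn.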
def generate_text_from_prompt(prompt, tokenizer, model):
    """
    Generate the output from the prompt.
    param:
        prompt (str): the prompt inputted to the model
        tokenizer : the tokenizer that is used to encode / decode the input / output
        model : the model that is used to generate the output
    return:
        the response of the model (the decoded text still contains the prompt)
    """
    print("========== Prompt inputted to the model ==========\n", prompt)
    # Tokenize the prompt
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    ######################## TODO (Q1.1 ~ 1.4) ########################
    ### You can refer to https://huggingface.co/google/gemma-2-2b-it for basic usage
    ### Make sure to use 'do_sample=False' to get a deterministic response
    ### Otherwise the coherence score may be different from the sample answer
    # Generate response (greedy decoding, at most 512 new tokens)
    output_ids = model.generate(input_ids, max_new_tokens=512, do_sample=False)
    ###################################################################
    if output_ids is not None and len(output_ids) > 0:
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)
    else:
        return "Empty Response"
# With chat template
question = "Please tell me about the key differences between supervised learning and unsupervised learning. Answer in 200 words."
chat = [
{"role": "user", "content": question},
]
prompt_with_template = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
response_with_template = generate_text_from_prompt(prompt_with_template, tokenizer, model)
# extract the real output from the model
response_with_template = response_with_template.split('model\n')[-1].strip('\n').strip()
print("========== Output ==========\n", response_with_template)
score = calculate_coherence(question, response_with_template)
print(f"========== Coherence Score : {score:.4f} ==========")
========== Prompt inputted to the model ==========
<bos><start_of_turn>user
Please tell me about the key differences between supervised learning and unsupervised learning. Answer in 200 words.<end_of_turn>
<start_of_turn>model
The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.
========== Output ==========
Supervised and unsupervised learning are two fundamental types of machine learning.
**Supervised learning** involves training a model on labeled data, where each input has a corresponding output. The model learns to map inputs to outputs, making predictions on new, unseen data. Think of it like teaching a child with labeled examples: you show them pictures of cats and dogs, and they learn to distinguish between them.
**Unsupervised learning**, on the other hand, uses unlabeled data. The model learns patterns and structures within the data without explicit guidance. It's like letting a child explore a room full of toys and discover patterns on their own. Examples include clustering similar items or finding hidden relationships in data.
**Key differences:**
* **Labeling:** Supervised uses labeled data, unsupervised doesn't.
* **Goal:** Supervised aims to predict outputs, unsupervised aims to discover patterns.
* **Applications:** Supervised: classification, regression, supervised tasks; Unsupervised: clustering, anomaly detection, dimensionality reduction.
Both types are powerful tools with their own strengths and weaknesses, and the choice depends on the specific problem and available data.
========== Coherence Score : 6.0734 ==========
# Without chat template (directly using plain text)
response_without_template = generate_text_from_prompt(question, tokenizer, model)
# extract the real output from the model
response_without_template = response_without_template.split(question.split(' ')[-1])[-1].strip('\n').strip()
print("========== Output ==========\n", response_without_template)
score = calculate_coherence(question, response_without_template)
print(f"========== Coherence Score : {score:.4f} ==========")
========== Prompt inputted to the model ==========
Please tell me about the key differences between supervised learning and unsupervised learning. Answer in 200 words.
========== Output ==========
**Supervised Learning:**
* **Labeled data:** Uses data with known outputs (labels) to train models.
* **Goal:** Predict the output for new, unseen data.
* **Examples:** Image classification, spam detection, predicting house prices.
**Unsupervised Learning:**
* **Unlabeled data:** Uses data without known outputs to discover patterns.
* **Goal:** Explore data, identify clusters, or reduce dimensionality.
* **Examples:** Customer segmentation, anomaly detection, dimensionality reduction.
In essence, supervised learning learns from labeled examples to make predictions, while unsupervised learning explores unlabeled data to uncover hidden structures and relationships.
========== Coherence Score : 4.2210 ==========
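`calculate_coherence` returns the cross-encoder's raw logit, which is unbounded, so the two scores above are best read relative to each other (higher means the scorer judges the answer more relevant to the question). If a value in (0, 1) is easier to compare, a logistic sigmoid can be applied; this is an optional, assumed post-processing step, and because the sigmoid is monotonic it never changes which response ranks higher:

```python
import math


def logit_to_probability(score: float) -> float:
    """Map a raw cross-encoder logit into (0, 1) with a numerically stable sigmoid."""
    if score >= 0:
        return 1.0 / (1.0 + math.exp(-score))
    z = math.exp(score)  # avoids overflow of exp(-score) for very negative scores
    return z / (1.0 + z)


# Scores printed above (with / without chat template)
for label, score in [("with template", 6.0734), ("without template", 4.2210)]:
    print(f"{label}: logit={score:.4f}, sigmoid={logit_to_probability(score):.4f}")
```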
Q2: Multi-turn conversations
import matplotlib.pyplot as plt
import seaborn as sns
chat_history = []
round = 0
print("Chatbot: Hello! How can I assist you today? (Type 'exit' to quit)")
while True:
    user_input = input("You: ")
    if user_input.lower() == "exit":
        print("Chatbot: Goodbye!")
        break
    round += 1
    chat_history.append({"role": "user", "content": user_input})
    chat_template_format_prompt = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
    ######################## (Q2.1 ~ 2.3) ########################
    # Observe the prompt with chat template format that was inputted to the model in the current round to answer Q2.1 ~ Q2.3.
    print(f"=== Prompt with chat template format inputted to the model on round {round} ===\n{chat_template_format_prompt}")
    print("===============================================")
    ###################################################################
    inputs = tokenizer(chat_template_format_prompt, return_tensors="pt").to("cuda")
    # Get logits instead of directly generating
    with torch.no_grad():
        outputs_p = model(**inputs)
    logits = outputs_p.logits  # Logits of the model (raw scores before softmax)
    last_token_logits = logits[:, -1, :]  # Logits at the last prompt position, i.e. scores for the next token
    # Apply softmax to get probabilities over the whole vocabulary
    probs = torch.nn.functional.softmax(last_token_logits, dim=-1)
    # Get top-k tokens (e.g., 10)
    top_k = 10
    top_probs, top_indices = torch.topk(probs, top_k)
    # Convert to numpy for plotting
    top_probs = top_probs.cpu().squeeze().numpy()
    top_indices = top_indices.cpu().squeeze().numpy()
    top_tokens = [tokenizer.decode([idx]) for idx in top_indices]
    # Plot probability distribution (hue + legend=False avoids seaborn's palette deprecation warning)
    plt.figure(figsize=(10, 5))
    sns.barplot(x=top_probs, y=top_tokens, hue=top_tokens, palette="coolwarm", legend=False)
    plt.xlabel("Probability")
    plt.ylabel("Token")
    plt.title("Top Token Probabilities for Next Token")
    plt.show()
    # Generate response (greedy decoding)
    outputs = model.generate(**inputs, max_new_tokens=200, pad_token_id=tokenizer.eos_token_id, do_sample=False)
    # Decode only the newly generated tokens
    response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
    print(f"Chatbot: {response}")
    chat_history.append({"role": "assistant", "content": response})
Chatbot: Hello! How can I assist you today? (Type 'exit' to quit)
You: hello
The 'batch_size' argument of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'max_batch_size' argument instead.
=== Prompt with chat template format inputted to the model on round 1 ===
<bos><start_of_turn>user
hello<end_of_turn>
<start_of_turn>model
===============================================
<ipython-input-11-262378176447>:44: FutureWarning:
Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.
sns.barplot(x=top_probs, y=top_tokens, palette="coolwarm")
/usr/local/lib/python3.11/dist-packages/IPython/core/pylabtools.py:151: UserWarning: Glyph 128075 (\N{WAVING HAND SIGN}) missing from font(s) DejaVu Sans.
fig.canvas.print_figure(bytes_io, **kw)
/usr/local/lib/python3.11/dist-packages/IPython/core/pylabtools.py:151: UserWarning: Glyph 20320 (\N{CJK UNIFIED IDEOGRAPH-4F60}) missing from font(s) DejaVu Sans.
fig.canvas.print_figure(bytes_io, **kw)
/usr/local/lib/python3.11/dist-packages/IPython/core/pylabtools.py:151: UserWarning: Glyph 22909 (\N{CJK UNIFIED IDEOGRAPH-597D}) missing from font(s) DejaVu Sans.
fig.canvas.print_figure(bytes_io, **kw)
Chatbot: Hello! đ
How can I help you today? đ
You: ćäšĺŚäš
=== Prompt with chat template format inputted to the model on round 2 ===
<bos><start_of_turn>user
hello<end_of_turn>
<start_of_turn>model
Hello! đ
How can I help you today? đ<end_of_turn>
<start_of_turn>user
ćäšĺŚäš <end_of_turn>
<start_of_turn>model
===============================================
<ipython-input-11-262378176447>:44: FutureWarning:
Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.
sns.barplot(x=top_probs, y=top_tokens, palette="coolwarm")
/usr/local/lib/python3.11/dist-packages/IPython/core/pylabtools.py:151: UserWarning: Glyph 23398 (\N{CJK UNIFIED IDEOGRAPH-5B66}) missing from font(s) DejaVu Sans.
fig.canvas.print_figure(bytes_io, **kw)
/usr/local/lib/python3.11/dist-packages/IPython/core/pylabtools.py:151: UserWarning: Glyph 20064 (\N{CJK UNIFIED IDEOGRAPH-4E60}) missing from font(s) DejaVu Sans.
fig.canvas.print_figure(bytes_io, **kw)
Chatbot: That's a great question! "How to learn" is a very broad topic. To give you the best advice, I need a little more information.
Could you tell me:
* **What do you want to learn?** (e.g., a new language, a specific skill, a new subject, etc.)
* **What is your learning style?** (e.g., visual, auditory, hands-on, etc.)
* **How much time do you have to dedicate to learning?**
* **What resources do you have access to?** (e.g., books, online courses, tutors, etc.)
Once I have a better understanding of your needs, I can give you some personalized recommendations. đđťđĄ
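Each round of the loop plots the top-10 entries of a softmax over the whole vocabulary, taken at the last prompt position. The same computation on a toy five-token vocabulary (the logits are made up for illustration) looks like this in plain Python; `torch.nn.functional.softmax` and `torch.topk` do the same on tensors:

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_tokens(logits, vocab, k):
    """Return the k most probable (token, probability) pairs, highest first."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(vocab[i], probs[i]) for i in ranked[:k]]


# Toy vocabulary and logits (made up)
vocab = ["Hello", "Hi", "你好", "!", "How"]
logits = [4.0, 2.5, 1.0, 0.5, 0.0]
for token, p in top_k_tokens(logits, vocab, k=3):
    print(f"{token!r}: {p:.3f}")
```

Because the probabilities are normalised over the full vocabulary, the ten bars in the plot need not add up to 1.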
Q3: Tokenization of Sentence
sentence = "I love taking a Machine Learning course by Professor Hung-yi Lee, What about you?" #@param {type:"string"}
######################## TODO (Q3.1 ~ 3.4) ########################
### You can refer to https://huggingface.co/learn/nlp-course/en/chapter2/4?fw=pt for basic tokenizer usage
### and https://huggingface.co/docs/transformers/en/main_classes/tokenizer for full tokenizer usage
# Encode the sentence into token IDs without adding special tokens
token_ids = tokenizer.encode(sentence, add_special_tokens=False)
# Convert the token IDs back to their corresponding tokens (words or subwords)
tokens = tokenizer.convert_ids_to_tokens(token_ids)
###################################################################
# Iterate through the tokens and their corresponding token IDs
for t, t_id in zip(tokens, token_ids):
    # Print the token and its index (ID)
    print(f"Token: {t}, token index: {t_id}")
Token: I, token index: 235285
Token: ▁love, token index: 2182
Token: ▁taking, token index: 4998
Token: ▁a, token index: 476
Token: ▁Machine, token index: 13403
Token: ▁Learning, token index: 14715
Token: ▁course, token index: 3205
Token: ▁by, token index: 731
Token: ▁Professor, token index: 11325
Token: ▁Hung, token index: 18809
Token: -, token index: 235290
Token: yi, token index: 12636
Token: ▁Lee, token index: 9201
Token: ,, token index: 235269
Token: ▁What, token index: 2439
Token: ▁about, token index: 1105
Token: ▁you, token index: 692
Token: ?, token index: 235336
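The `▁` prefix is SentencePiece's marker for a token that starts with a space, which is why `-` and `yi` (no marker) attach directly to `Hung`. A simplified sketch of reversing that convention on the tokens printed above; in practice `tokenizer.decode(token_ids)` or `tokenizer.convert_tokens_to_string(tokens)` should be used, since real decoders also handle special and byte-fallback tokens:

```python
SPIECE_UNDERLINE = "\u2581"  # "▁", the SentencePiece word-boundary marker


def detokenize(tokens):
    """Rebuild text from SentencePiece-style tokens by turning '▁' back into spaces."""
    return "".join(tokens).replace(SPIECE_UNDERLINE, " ")


# Tokens printed in Q3 above
tokens = ["I", "▁love", "▁taking", "▁a", "▁Machine", "▁Learning", "▁course",
          "▁by", "▁Professor", "▁Hung", "-", "yi", "▁Lee", ",", "▁What",
          "▁about", "▁you", "?"]
print(detokenize(tokens))
```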
Q4: Auto-regressive generation
from tqdm import trange
from transformers import HybridCache
max_generation_tokens = 30
######################## TODO (Q4.3 ~ 4.6) ########################
# Modify the value of k and p accordingly
top_k = 2 # Set K for top-k sampling
top_p = 0.6 # Set P for nucleus sampling
###################################################################
# Input prompt
prompt = f"Generate a paraphrase of the sentence 'Professor Hung-yi Lee is one of the best teachers in the domain of machine learning'. Just response with one sentence."
input_ids = tokenizer(prompt, return_tensors="pt")
# Initialize KV Cache
# Note: nothing in this cell uses kv_cache, next_token_id, attention_mask or cache_position;
# model.generate() below builds and manages its own cache.
kv_cache = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generation_tokens, device="cuda", dtype=torch.float16)
next_token_id = input_ids.input_ids.to("cuda")
attention_mask = input_ids.attention_mask.to("cuda")
cache_position = torch.arange(attention_mask.shape[1], device="cuda")
generated_sentences_top_k = []
generated_sentences_top_p = []
# Define the generation parameters
generation_params = {
"do_sample": True, # Enable sampling
"max_length": max_generation_tokens + len(input_ids.input_ids[0]), # Total length including prompt
"pad_token_id": tokenizer.pad_token_id, # Ensure padding token is set
"eos_token_id": tokenizer.eos_token_id, # Ensure EOS token is set
"bos_token_id": tokenizer.bos_token_id, # Ensure BOS token is set
"attention_mask": input_ids.attention_mask.to("cuda"), # Move attention mask to GPU
"use_cache": True, # Enable caching
"return_dict_in_generate": True, # Return generation outputs
"output_scores": False, # Disable outputting scores
}
for method in ["top-k", "top-p"]:
    for _ in trange(20):
        if method == "top-k":
            # Generate text using the model with top_k
            generated_output = model.generate(
                input_ids=input_ids.input_ids.to("cuda"),
                top_k=top_k,
                **generation_params
            )
        elif method == "top-p":
            # Generate text using the model with top_p
            ######################## TODO (Q4.3 ~ 4.6) ########################
            # Generate output from the model based on the input_ids and specified generation parameters
            # You can refer to this documentation: https://huggingface.co/docs/transformers/en/main_classes/text_generation
            # Hint: You can check how we generate the text with top_k
            # Note: top_k is not passed here, so the GenerationConfig default (top_k=50, unless the
            # model's generation config overrides it) is still applied alongside top_p
            generated_output = model.generate(
                input_ids=input_ids.input_ids.to("cuda"),
                top_p=top_p,
                **generation_params
            )
            ###################################################################
        else:
            raise NotImplementedError()
        # Decode only the newly generated tokens (skip the prompt)
        generated_tokens = generated_output.sequences[0, len(input_ids.input_ids[0]):]
        decoded_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        # Tidy up spacing around punctuation
        sentence = decoded_text.replace(" ,", ",").replace(" 's", "'s").replace(" .", ".").strip()
        # Append the generated sentence to the appropriate list
        if method == "top-k":
            generated_sentences_top_k.append(sentence)
        else:
            generated_sentences_top_p.append(sentence)
# Print results
print("===== Top-K Sampling Output =====")
print()
for idx, sentence in enumerate(generated_sentences_top_k):
    print(f"{idx}. {sentence}")
print()
print("===== Top-P Sampling Output =====")
print()
for idx, sentence in enumerate(generated_sentences_top_p):
    print(f"{idx}. {sentence}")
print()
100%|██████████| 20/20 [00:29<00:00, 1.46s/it]
100%|██████████| 20/20 [00:28<00:00, 1.43s/it]
===== Top-K Sampling Output =====
0. Professor Lee is highly regarded as a leading expert in machine learning education.
1. Professor Lee is highly respected and renowned in the field of machine learning, making him one of the top teachers in the area.
2. Professor Lee is highly regarded as an expert in machine learning and is known for his exceptional teaching skills.
3. He is highly regarded as an expert in the field of machine learning and is known for being an outstanding teacher.
4. Professor Hung-yi Lee is highly regarded for his expertise and contributions in the field of machine learning.
5. He is highly regarded for his expertise in machine learning and consistently delivers excellent instruction.
6. Professor Lee is highly regarded for his exceptional expertise and contributions to machine learning education.
7. Professor Lee is highly regarded for his expertise in machine learning and consistently ranks among the top instructors in the field.
8. Professor Lee is highly regarded for his expertise in machine learning and is considered a top-tier educator in the field.
9. Professor Lee is highly regarded as a leading expert in machine learning, known for his exceptional teaching abilities.
10. Professor Lee is highly respected and renowned for his expertise in machine learning education.
11. Professor Lee is highly regarded as a leading expert in machine learning education.
12. Professor Lee is highly regarded as an expert in machine learning and is known for being a top-tier teacher in the field.
13. Professor Lee is highly regarded as a top expert in machine learning education.
14. He is highly regarded for his expertise and contributions to the field of machine learning, making him a leading teacher.
15. Professor Lee's exceptional expertise in machine learning has made him a highly regarded and respected figure in the field.
16. Professor Lee is highly regarded as a leading expert in machine learning education.
17. Professor Lee is highly regarded as a leading expert in machine learning education.
18. Professor Lee is highly regarded as a leading expert in machine learning, renowned for his exceptional teaching abilities.
19. He is highly regarded as an expert in machine learning education.
===== Top-P Sampling Output =====
0. Professor Lee is highly regarded for his expertise in machine learning and is considered a top teacher in the field.
1. Professor Lee is highly regarded as a leading expert in machine learning education.
2. Professor Lee is highly regarded for his expertise in machine learning, making him a top teacher in the field.
3. Professor Lee is highly regarded for his expertise in machine learning, making him a top teacher in the field.
4. Professor Lee is highly regarded for his expertise in machine learning and is known for being one of the top educators in the field.
5. Professor Lee is highly regarded as a leading expert in machine learning education.
6. Professor Lee is highly regarded as a leading expert in machine learning education.
7. Professor Lee is highly regarded for his expertise in machine learning and is considered one of the leading educators in the field.
8. Professor Lee is highly regarded for his expertise in machine learning and is known for his exceptional teaching abilities.
9. Professor Lee is highly regarded as a leading expert in machine learning education.
10. Professor Lee is highly regarded as a leading expert in machine learning education.
11. Professor Lee is highly regarded as a leading expert in machine learning education.
12. Professor Lee is highly regarded as a leading expert in machine learning education.
13. Professor Lee is highly regarded for his expertise in machine learning, making him one of the top educators in the field.
14. Professor Lee is highly regarded for his expertise in machine learning, making him a top teacher in the field.
15. Professor Lee is highly regarded as a top expert in machine learning education.
16. Professor Lee is highly regarded as a leading expert in machine learning education.
17. Professor Lee is highly regarded as a top expert in machine learning education.
18. Professor Lee is highly regarded as a leading expert in machine learning education.
19. Professor Lee is highly regarded as a leading expert in machine learning education.
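Both sampling methods first shrink the next-token distribution and then sample from what is left: top-k keeps a fixed number of candidates, while top-p keeps the smallest set of most probable tokens whose cumulative probability reaches p. A plain-Python sketch on a made-up distribution (Hugging Face applies these as logits processors, and its boundary handling may differ slightly):

```python
import random


def top_k_filter(probs, k):
    """Keep the k most probable tokens and renormalise."""
    keep = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    total = sum(probs[i] for i in keep)
    return {i: probs[i] / total for i in keep}


def top_p_filter(probs, p):
    """Keep the smallest set of most probable tokens whose total probability reaches p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = [], 0.0
    for i in order:
        keep.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    total = sum(probs[i] for i in keep)
    return {i: probs[i] / total for i in keep}


def sample(filtered, rng):
    """Draw one token index from an {index: probability} dict."""
    indices = list(filtered)
    return rng.choices(indices, weights=[filtered[i] for i in indices], k=1)[0]


probs = [0.5, 0.3, 0.15, 0.05]       # toy next-token distribution
print(top_k_filter(probs, k=2))      # tokens 0 and 1 survive
print(top_p_filter(probs, p=0.6))    # 0.5 < 0.6, so token 1 is added (0.8 >= 0.6)
print(sample(top_p_filter(probs, p=0.6), random.Random(0)))
```

One plausible reading of the outputs above: when the model is confident, p = 0.6 can leave a single candidate, so top-p behaves almost greedily and repeats the same sentence, whereas k = 2 always keeps two candidates. The self-BLEU scores below quantify this.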
from nltk.translate.bleu_score import sentence_bleu
def compute_self_bleu(generated_sentences):
    # Average BLEU of each sentence against every other sentence (higher = less diverse)
    total_bleu_score = 0
    num_sentences = len(generated_sentences)
    for i, hypothesis in enumerate(generated_sentences):
        references = [generated_sentences[j] for j in range(num_sentences) if j != i]
        bleu_scores = [sentence_bleu([ref.split()], hypothesis.split()) for ref in references]
        total_bleu_score += sum(bleu_scores) / len(bleu_scores)
    return total_bleu_score / num_sentences
# Calculate BLEU score
bleu_score = compute_self_bleu(generated_sentences_top_k)
print(f"self-BLEU Score for top_k (k={top_k}): {bleu_score:.4f}")
# Calculate BLEU score
bleu_score = compute_self_bleu(generated_sentences_top_p)
print(f"self-BLEU Score for top_p (p={top_p}): {bleu_score:.4f}")
self-BLEU Score for top_k (k=2): 0.2623
self-BLEU Score for top_p (p=0.6): 0.5228
/usr/local/lib/python3.11/dist-packages/nltk/translate/bleu_score.py:577: UserWarning:
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
warnings.warn(_msg)
/usr/local/lib/python3.11/dist-packages/nltk/translate/bleu_score.py:577: UserWarning:
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
warnings.warn(_msg)
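The warnings above appear because many sentence pairs share no 3-gram or 4-gram, so the unsmoothed BLEU for such a pair collapses to 0 and pulls the average down. A sketch of the same self-BLEU loop using NLTK's `SmoothingFunction` (as the warning itself suggests), with `method1`, which replaces zero n-gram counts by a small epsilon; scores will differ from the unsmoothed numbers printed above, so compare top-k and top-p only under the same setting:

```python
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu


def compute_self_bleu_smoothed(generated_sentences):
    """Self-BLEU with NLTK smoothing (method1); needs at least two sentences.

    Each sentence is scored against every other sentence, one reference at a
    time, and the per-sentence averages are averaged again.
    """
    smoother = SmoothingFunction().method1
    total = 0.0
    n = len(generated_sentences)
    for i, hypothesis in enumerate(generated_sentences):
        hyp_tokens = hypothesis.split()
        scores = [
            sentence_bleu([ref.split()], hyp_tokens, smoothing_function=smoother)
            for j, ref in enumerate(generated_sentences)
            if j != i
        ]
        total += sum(scores) / len(scores)
    return total / n


# Usage with the lists built above:
# compute_self_bleu_smoothed(generated_sentences_top_k)
```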
Q5: t-SNE
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
######################## (Q5.2 ~ 5.3) ########################
# Sentences with different meanings of words
sentences = [
"I ate a fresh apple.", # Apple (fruit)
"Apple released the new iPhone.", # Apple (company)
"I peeled an orange and ate it.", # Orange (fruit)
"The Orange network has great coverage.", # Orange (telecom)
"Microsoft announced a new update.", # Microsoft (company)
"Banana is my favorite fruit.", # Banana (fruit)
]
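# Tokenize and move to device (padding=True pads shorter sentences to the longest one)
inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
inputs = inputs.to(device)
# Get hidden states
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
hidden_states = outputs.hidden_states[-1].float()  # Extract last layer embeddings (float32 for a stable sum)
# Compute sentence-level embeddings (mean pooling over real tokens only; padding positions are masked out)
mask = inputs.attention_mask.unsqueeze(-1).float()
sentence_embeddings = ((hidden_states * mask).sum(dim=1) / mask.sum(dim=1)).cpu().numpy()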
# Words to visualize
word_labels = [
"Apple (fruit)", "Apple (company)",
"Orange (fruit)", "Orange (telecom)",
"Microsoft (company)", "Banana (fruit)"
]
# Reduce to 2D using t-SNE
tsne = TSNE(n_components=2, perplexity=2, random_state=42)
embeddings_2d = tsne.fit_transform(sentence_embeddings)
# Plot the embeddings
plt.figure(figsize=(8, 6))
colors = ["red", "blue", "orange", "purple", "green", "brown"]
for i, label in enumerate(word_labels):
    plt.scatter(embeddings_2d[i, 0], embeddings_2d[i, 1], color=colors[i], s=100)
    plt.text(embeddings_2d[i, 0] + 0.1, embeddings_2d[i, 1] + 0.1, label, fontsize=12, color=colors[i])
plt.xlabel("t-SNE Dim 1")
plt.ylabel("t-SNE Dim 2")
plt.title("t-SNE Visualization of Sentence Embeddings (mean-pooled last hidden layer)")
plt.show()
##################################################
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
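With `padding=True`, shorter sentences in the batch are padded up to the longest one, and a plain `hidden_states.mean(dim=1)` averages those padding positions in as well. A pure-Python sketch of mask-aware mean pooling on a toy batch (values made up); in PyTorch the equivalent is `(hidden * mask).sum(dim=1) / mask.sum(dim=1)` with the attention mask expanded to the hidden size:

```python
def masked_mean_pool(hidden_states, attention_mask):
    """Average token vectors per sentence, ignoring padding positions.

    hidden_states: [batch][seq_len][dim] nested lists of floats
    attention_mask: [batch][seq_len] of 0/1 (1 = real token, 0 = padding)
    """
    pooled = []
    for tokens, mask in zip(hidden_states, attention_mask):
        dim = len(tokens[0])
        sums = [0.0] * dim
        count = 0
        for vector, keep in zip(tokens, mask):
            if keep:
                count += 1
                for d in range(dim):
                    sums[d] += vector[d]
        pooled.append([s / max(count, 1) for s in sums])  # max() guards an all-padding row
    return pooled


# Toy batch: the second sentence has one (left) padding position holding a junk vector
hidden = [
    [[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]],
    [[100.0, 100.0], [2.0, 4.0], [4.0, 8.0]],
]
mask = [[1, 1, 1], [0, 1, 1]]
print(masked_mean_pool(hidden, mask))  # [[3.0, 3.0], [3.0, 6.0]]
```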
Q6: Observe the Attention Weight
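Per the Hugging Face model-output documentation, `outputs.attentions` is a tuple with one tensor per layer, each shaped `(batch_size, num_heads, sequence_length, sequence_length)` for a full forward pass (with a KV cache, the query dimension covers only the tokens fed in at that step). So `attention_scores[layer_idx][0][head_idx]` in the code below is one head's query-by-key matrix, and each row is a softmax that sums to 1. A minimal sketch of standard causal scaled dot-product attention weights for one head, on toy vectors; Gemma 2 adds details such as sliding-window layers and attention-logit soft-capping that are omitted here:

```python
import math


def causal_attention_weights(queries, keys):
    """Scaled dot-product attention weights for one head with a causal mask.

    Row i is softmax over positions 0..i of (q_i . k_j) / sqrt(d); later
    positions are masked to weight 0, so every row sums to 1.
    """
    d = len(queries[0])
    weights = []
    for i, q in enumerate(queries):
        scores = [sum(a * b for a, b in zip(q, keys[j])) / math.sqrt(d) for j in range(i + 1)]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        row = [e / total for e in exps] + [0.0] * (len(keys) - i - 1)
        weights.append(row)
    return weights


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
for row in causal_attention_weights(q, k):
    print([round(w, 3) for w in row])
```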
# Import necessary libraries
import torch
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import trange
from sklearn.decomposition import PCA
# Input prompt for text generation
prompt = "Google "
input_ids = tokenizer(prompt, return_tensors="pt") # Tokenize the input prompt
next_token_id = input_ids.input_ids.to("cuda") # Move input token ids to GPU
attention_mask = input_ids.attention_mask.to("cuda") # Move attention mask to GPU
cache_position = torch.arange(attention_mask.shape[1], device="cuda") # Position for the KV cache
# Set the number of tokens to generate and other parameters
generation_tokens = 20 # Limit for visualization (number of tokens to generate)
total_tokens = generation_tokens + next_token_id.size(1) - 1 # Total tokens to handle
layer_idx = 10 # Specify the layer index for attention visualization
head_idx = 7 # Specify the attention head index to visualize
# KV cache setup for caching key/values across time steps
from transformers.cache_utils import HybridCache
kv_cache = HybridCache(config=model.config, max_batch_size=1, max_cache_len=total_tokens, device="cuda", dtype=torch.float16)
generated_tokens = [] # List to store generated tokens
attentions = None # Placeholder to store attention weights
num_new_tokens = 0 # Counter for the number of new tokens generated
model.eval() # Set the model to evaluation mode
# Generate tokens and collect attention weights for visualization
for num_new_tokens in range(generation_tokens):
    with torch.no_grad():  # Disable gradients during inference for efficiency
        # Pass the input through the model to get the next token prediction and attention weights
        outputs = model(
            next_token_id,
            attention_mask=attention_mask,
            cache_position=cache_position,
            use_cache=True,  # Use the KV cache for efficiency
            past_key_values=kv_cache,  # Provide the cached key-value pairs for fast inference
            output_attentions=True  # Enable the extraction of attention weights
        )
    ######################## TODO (Q6.1 ~ 6.4) ########################
    ### You can refer to https://huggingface.co/docs/transformers/en/main_classes/output#transformers.modeling_outputs.BaseModelOutput.attentions to see the structure of model output attentions
    # Get the logits for the last generated token from outputs
    logits = outputs.logits[:, -1, :]
    # Extract the attention scores from the model's outputs
    attention_scores = outputs.attentions
    ###################################################################
    # Extract attention weights for the specified layer and head: shape (query_len, key_len)
    layer_attention = attention_scores[layer_idx][0][head_idx].detach().cpu().numpy()
    # If it's the first generated token, initialize the attentions array
    if num_new_tokens == 0:
        attentions = layer_attention
    else:
        # Append the new query row(s) to the existing array
        attentions = np.append(attentions, layer_attention, axis=0)
    # Choose the next token greedily: the one with the highest logit
    next_token_id = logits.argmax(dim=-1)
    generated_tokens.append(next_token_id.item())  # Add the token ID to the generated tokens list
    # Update the attention mask and next token ID for the next iteration
    attention_mask = torch.cat([attention_mask, torch.ones(1, 1, dtype=attention_mask.dtype, device="cuda")], dim=-1)  # Add a mask entry (same dtype) for the generated token
    next_token_id = next_token_id.unsqueeze(0)  # Reshape to (batch, seq_len) = (1, 1)
    # Update the KV cache with the new past key-values
    kv_cache = outputs.past_key_values
    cache_position = cache_position[-1:] + 1  # Update the cache position for the next iteration
# Decode the generated tokens into human-readable text
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
full_text = prompt + generated_text # Combine the prompt with the generated text
# Tokenize all the generated text (prompt + generated)
tokens = tokenizer.tokenize(full_text)
# Function to plot a heatmap of attention weights
def plot_attention(attn_matrix, tokens, title="Attention Heatmap"):
    plt.figure(figsize=(10, 8))  # Set the figure size
    sns.heatmap(attn_matrix, xticklabels=tokens, yticklabels=tokens, cmap="viridis", annot=False)  # Plot the attention matrix as a heatmap
    plt.xlabel("Key Tokens")
    plt.ylabel("Query Tokens")
    plt.title(title)
    plt.xticks(rotation=45)  # Rotate x-axis labels for better visibility
    plt.yticks(rotation=0)  # Keep y-axis labels horizontal
    plt.show()
# Plot the full query-by-key attention map (prompt and generated tokens) for the chosen layer and head
plot_attention(attentions, tokens, title=f"Attention Weights of Layer {layer_idx}, Head {head_idx}")
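Each row of the heatmap is one query token's softmax distribution over the key tokens it is allowed to see. Because the model is causal, entries above the diagonal are zero, and because the KV cache keeps earlier keys and values, each step after the prompt only has to compute the newest row; stacking those rows with `np.append` rebuilds the same matrix a single full forward pass would give. The NumPy sketch below checks this on random toy tensors for one head. It is an illustration under simplifying assumptions, not Gemma's actual attention: it omits model-specific details such as grouped-query attention, sliding-window layers, and any logit soft-capping, and `causal_attention` is a hypothetical helper.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal mask.

    q, k, v: (seq_len, head_dim). Returns (output, weights), where
    weights[i, j] is how much query token i attends to key token j.
    """
    seq_len, head_dim = q.shape
    scores = q @ k.T / np.sqrt(head_dim)
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)  # positions j > i
    scores = np.where(future, -np.inf, scores)
    weights = softmax(scores, axis=-1)
    return weights @ v, weights

rng = np.random.default_rng(0)
T, d = 5, 4
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))

out_full, w_full = causal_attention(q, k, v)

# Incremental decoding: one query per step, attending over the cached keys/values so far.
rows, outs = [], []
for t in range(T):
    k_cache, v_cache = k[: t + 1], v[: t + 1]       # keys/values for positions 0..t
    w_t = softmax(q[t] @ k_cache.T / np.sqrt(d))    # weights of the newest query, shape (t + 1,)
    outs.append(w_t @ v_cache)
    rows.append(np.pad(w_t, (0, T - t - 1)))        # zero-fill positions not seen yet
w_incremental = np.stack(rows)

print(np.allclose(w_full, w_incremental))     # True
print(np.allclose(out_full, np.stack(outs)))  # True
print(np.allclose(w_full.sum(axis=-1), 1.0))  # True: every row is a probability distribution
```

The zero-filled tail of each incremental row appears to correspond to the not-yet-written slots of the fixed-size cache in the loop above, which is why the stacked `attentions` array comes out square.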
Q7: Observe the Activation Scores
The following code is adapted from the official Gemma tutorials: Gemma Tutorial From Scratch and SAELens.
!pip install sae-lens
Collecting sae-lens
Downloading sae_lens-5.9.1-py3-none-any.whl.metadata (5.3 kB)
Collecting automated-interpretability<1.0.0,>=0.0.5 (from sae-lens)
Downloading automated_interpretability-0.0.9-py3-none-any.whl.metadata (822 bytes)
Collecting babe<0.0.8,>=0.0.7 (from sae-lens)
Downloading babe-0.0.7-py3-none-any.whl.metadata (10 kB)
Collecting datasets<3.0.0,>=2.17.1 (from sae-lens)
Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Requirement already satisfied: matplotlib<4.0.0,>=3.8.3 in /usr/local/lib/python3.11/dist-packages (from sae-lens) (3.10.0)
Requirement already satisfied: matplotlib-inline<0.2.0,>=0.1.6 in /usr/local/lib/python3.11/dist-packages (from sae-lens) (0.1.7)
Requirement already satisfied: nltk<4.0.0,>=3.8.1 in /usr/local/lib/python3.11/dist-packages (from sae-lens) (3.9.1)
Requirement already satisfied: plotly<6.0.0,>=5.19.0 in /usr/local/lib/python3.11/dist-packages (from sae-lens) (5.24.1)
Collecting plotly-express<0.5.0,>=0.4.1 (from sae-lens)
Downloading plotly_express-0.4.1-py2.py3-none-any.whl.metadata (1.7 kB)
Collecting pytest-profiling<2.0.0,>=1.7.0 (from sae-lens)
Downloading pytest_profiling-1.8.1-py3-none-any.whl.metadata (15 kB)
Collecting python-dotenv<2.0.0,>=1.0.1 (from sae-lens)
Downloading python_dotenv-1.1.0-py3-none-any.whl.metadata (24 kB)
Requirement already satisfied: pyyaml<7.0.0,>=6.0.1 in /usr/local/lib/python3.11/dist-packages (from sae-lens) (6.0.2)
Collecting pyzmq==26.0.0 (from sae-lens)
Downloading pyzmq-26.0.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (6.2 kB)
Collecting safetensors<0.5.0,>=0.4.2 (from sae-lens)
Downloading safetensors-0.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Requirement already satisfied: simple-parsing<0.2.0,>=0.1.6 in /usr/local/lib/python3.11/dist-packages (from sae-lens) (0.1.7)
Collecting transformer-lens<3.0.0,>=2.0.0 (from sae-lens)
Downloading transformer_lens-2.15.0-py3-none-any.whl.metadata (12 kB)
Requirement already satisfied: transformers<5.0.0,>=4.38.1 in /usr/local/lib/python3.11/dist-packages (from sae-lens) (4.47.0)
Collecting typer<0.13.0,>=0.12.3 (from sae-lens)
Downloading typer-0.12.5-py3-none-any.whl.metadata (15 kB)
Requirement already satisfied: typing-extensions<5.0.0,>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from sae-lens) (4.13.2)
Collecting zstandard<0.23.0,>=0.22.0 (from sae-lens)
Downloading zstandard-0.22.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.9 kB)
Collecting blobfile<3.0.0,>=2.1.1 (from automated-interpretability<1.0.0,>=0.0.5->sae-lens)
Downloading blobfile-2.1.1-py3-none-any.whl.metadata (15 kB)
Collecting boostedblob<0.16.0,>=0.15.3 (from automated-interpretability<1.0.0,>=0.0.5->sae-lens)
Downloading boostedblob-0.15.6-py3-none-any.whl.metadata (2.0 kB)
Collecting httpx<0.28.0,>=0.27.0 (from automated-interpretability<1.0.0,>=0.0.5->sae-lens)
Downloading httpx-0.27.2-py3-none-any.whl.metadata (7.1 kB)
Collecting numpy<2.0.0,>=1.24.0 (from automated-interpretability<1.0.0,>=0.0.5->sae-lens)
Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Requirement already satisfied: orjson<4.0.0,>=3.10.1 in /usr/local/lib/python3.11/dist-packages (from automated-interpretability<1.0.0,>=0.0.5->sae-lens) (3.10.16)
Requirement already satisfied: scikit-learn<2.0.0,>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from automated-interpretability<1.0.0,>=0.0.5->sae-lens) (1.6.1)
Collecting tiktoken>=0.6.0 (from automated-interpretability<1.0.0,>=0.0.5->sae-lens)
Downloading tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (from babe<0.0.8,>=0.0.7->sae-lens) (2.2.2)
Collecting py2store (from babe<0.0.8,>=0.0.7->sae-lens)
Downloading py2store-0.1.20.tar.gz (143 kB)
Preparing metadata (setup.py) ... done
Collecting graze (from babe<0.0.8,>=0.0.7->sae-lens)
Downloading graze-0.1.29-py3-none-any.whl.metadata (6.7 kB)
Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (3.18.0)
Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.11/dist-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (18.1.0)
Collecting dill<0.3.9,>=0.3.0 (from datasets<3.0.0,>=2.17.1->sae-lens)
Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.11/dist-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (2.32.3)
Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.11/dist-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (4.67.1)
Collecting xxhash (from datasets<3.0.0,>=2.17.1->sae-lens)
Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets<3.0.0,>=2.17.1->sae-lens)
Downloading multiprocess-0.70.18-py311-none-any.whl.metadata (7.5 kB)
Collecting fsspec<=2024.6.1,>=2023.1.0 (from fsspec[http]<=2024.6.1,>=2023.1.0->datasets<3.0.0,>=2.17.1->sae-lens)
Downloading fsspec-2024.6.1-py3-none-any.whl.metadata (11 kB)
Requirement already satisfied: aiohttp in /usr/local/lib/python3.11/dist-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (3.11.15)
Requirement already satisfied: huggingface-hub>=0.21.2 in /usr/local/lib/python3.11/dist-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (0.30.2)
Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (24.2)
Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (1.3.2)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (4.57.0)
Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (1.4.8)
Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (11.1.0)
Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (3.2.3)
Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.11/dist-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (2.8.2)
Requirement already satisfied: traitlets in /usr/local/lib/python3.11/dist-packages (from matplotlib-inline<0.2.0,>=0.1.6->sae-lens) (5.7.1)
Requirement already satisfied: click in /usr/local/lib/python3.11/dist-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (8.1.8)
Requirement already satisfied: joblib in /usr/local/lib/python3.11/dist-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (1.4.2)
Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.11/dist-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (2024.11.6)
Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.11/dist-packages (from plotly<6.0.0,>=5.19.0->sae-lens) (9.1.2)
Requirement already satisfied: statsmodels>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (0.14.4)
Requirement already satisfied: scipy>=0.18 in /usr/local/lib/python3.11/dist-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (1.14.1)
Requirement already satisfied: patsy>=0.5 in /usr/local/lib/python3.11/dist-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (1.0.1)
Requirement already satisfied: six in /usr/local/lib/python3.11/dist-packages (from pytest-profiling<2.0.0,>=1.7.0->sae-lens) (1.17.0)
Requirement already satisfied: pytest in /usr/local/lib/python3.11/dist-packages (from pytest-profiling<2.0.0,>=1.7.0->sae-lens) (8.3.5)
Collecting gprof2dot (from pytest-profiling<2.0.0,>=1.7.0->sae-lens)
Downloading gprof2dot-2025.4.14-py3-none-any.whl.metadata (19 kB)
Requirement already satisfied: docstring-parser<1.0,>=0.15 in /usr/local/lib/python3.11/dist-packages (from simple-parsing<0.2.0,>=0.1.6->sae-lens) (0.16)
Requirement already satisfied: accelerate>=0.23.0 in /usr/local/lib/python3.11/dist-packages (from transformer-lens<3.0.0,>=2.0.0->sae-lens) (1.5.2)
Collecting beartype<0.15.0,>=0.14.1 (from transformer-lens<3.0.0,>=2.0.0->sae-lens)
Downloading beartype-0.14.1-py3-none-any.whl.metadata (28 kB)
Collecting better-abc<0.0.4,>=0.0.3 (from transformer-lens<3.0.0,>=2.0.0->sae-lens)
Downloading better_abc-0.0.3-py3-none-any.whl.metadata (1.4 kB)
Requirement already satisfied: einops>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from transformer-lens<3.0.0,>=2.0.0->sae-lens) (0.8.1)
Collecting fancy-einsum>=0.0.3 (from transformer-lens<3.0.0,>=2.0.0->sae-lens)
Downloading fancy_einsum-0.0.3-py3-none-any.whl.metadata (1.2 kB)
Collecting jaxtyping>=0.2.11 (from transformer-lens<3.0.0,>=2.0.0->sae-lens)
Downloading jaxtyping-0.3.1-py3-none-any.whl.metadata (7.0 kB)
Requirement already satisfied: rich>=12.6.0 in /usr/local/lib/python3.11/dist-packages (from transformer-lens<3.0.0,>=2.0.0->sae-lens) (13.9.4)
Requirement already satisfied: sentencepiece in /usr/local/lib/python3.11/dist-packages (from transformer-lens<3.0.0,>=2.0.0->sae-lens) (0.2.0)
Requirement already satisfied: torch>=2.2 in /usr/local/lib/python3.11/dist-packages (from transformer-lens<3.0.0,>=2.0.0->sae-lens) (2.6.0+cu124)
Collecting transformers-stream-generator<0.0.6,>=0.0.5 (from transformer-lens<3.0.0,>=2.0.0->sae-lens)
Downloading transformers-stream-generator-0.0.5.tar.gz (13 kB)
Preparing metadata (setup.py) ... done
Requirement already satisfied: typeguard<5.0,>=4.2 in /usr/local/lib/python3.11/dist-packages (from transformer-lens<3.0.0,>=2.0.0->sae-lens) (4.4.2)
Requirement already satisfied: wandb>=0.13.5 in /usr/local/lib/python3.11/dist-packages (from transformer-lens<3.0.0,>=2.0.0->sae-lens) (0.19.9)
Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.11/dist-packages (from transformers<5.0.0,>=4.38.1->sae-lens) (0.21.1)
Requirement already satisfied: shellingham>=1.3.0 in /usr/local/lib/python3.11/dist-packages (from typer<0.13.0,>=0.12.3->sae-lens) (1.5.4)
Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate>=0.23.0->transformer-lens<3.0.0,>=2.0.0->sae-lens) (5.9.5)
Collecting pycryptodomex~=3.8 (from blobfile<3.0.0,>=2.1.1->automated-interpretability<1.0.0,>=0.0.5->sae-lens)
Downloading pycryptodomex-3.22.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.4 kB)
Requirement already satisfied: urllib3<3,>=1.25.3 in /usr/local/lib/python3.11/dist-packages (from blobfile<3.0.0,>=2.1.1->automated-interpretability<1.0.0,>=0.0.5->sae-lens) (2.3.0)
Collecting lxml~=4.9 (from blobfile<3.0.0,>=2.1.1->automated-interpretability<1.0.0,>=0.0.5->sae-lens)
Downloading lxml-4.9.4-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (3.7 kB)
Collecting uvloop>=0.16.0 (from boostedblob<0.16.0,>=0.15.3->automated-interpretability<1.0.0,>=0.0.5->sae-lens)
Downloading uvloop-0.21.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.9 kB)
Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (2.6.1)
Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.3.2)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (25.3.0)
Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.5.0)
Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (6.4.3)
Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (0.3.1)
Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.19.0)
Requirement already satisfied: anyio in /usr/local/lib/python3.11/dist-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<1.0.0,>=0.0.5->sae-lens) (4.9.0)
Requirement already satisfied: certifi in /usr/local/lib/python3.11/dist-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<1.0.0,>=0.0.5->sae-lens) (2025.1.31)
Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.11/dist-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<1.0.0,>=0.0.5->sae-lens) (1.0.8)
Requirement already satisfied: idna in /usr/local/lib/python3.11/dist-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<1.0.0,>=0.0.5->sae-lens) (3.10)
Requirement already satisfied: sniffio in /usr/local/lib/python3.11/dist-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<1.0.0,>=0.0.5->sae-lens) (1.3.1)
Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.11/dist-packages (from httpcore==1.*->httpx<0.28.0,>=0.27.0->automated-interpretability<1.0.0,>=0.0.5->sae-lens) (0.14.0)
Collecting wadler-lindig>=0.1.3 (from jaxtyping>=0.2.11->transformer-lens<3.0.0,>=2.0.0->sae-lens)
Downloading wadler_lindig-0.1.5-py3-none-any.whl.metadata (17 kB)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas->babe<0.0.8,>=0.0.7->sae-lens) (2025.2)
Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas->babe<0.0.8,>=0.0.7->sae-lens) (2025.2)
Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets<3.0.0,>=2.17.1->sae-lens) (3.4.1)
Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.11/dist-packages (from rich>=12.6.0->transformer-lens<3.0.0,>=2.0.0->sae-lens) (3.0.0)
Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.11/dist-packages (from rich>=12.6.0->transformer-lens<3.0.0,>=2.0.0->sae-lens) (2.18.0)
Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<2.0.0,>=1.2.0->automated-interpretability<1.0.0,>=0.0.5->sae-lens) (3.6.0)
Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.2->transformer-lens<3.0.0,>=2.0.0->sae-lens) (3.4.2)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.2->transformer-lens<3.0.0,>=2.0.0->sae-lens) (3.1.6)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.2->transformer-lens<3.0.0,>=2.0.0->sae-lens)
Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.2->transformer-lens<3.0.0,>=2.0.0->sae-lens)
Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.2->transformer-lens<3.0.0,>=2.0.0->sae-lens)
Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.2->transformer-lens<3.0.0,>=2.0.0->sae-lens)
Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.2->transformer-lens<3.0.0,>=2.0.0->sae-lens)
Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=2.2->transformer-lens<3.0.0,>=2.0.0->sae-lens)
Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch>=2.2->transformer-lens<3.0.0,>=2.0.0->sae-lens)
Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch>=2.2->transformer-lens<3.0.0,>=2.0.0->sae-lens)
Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch>=2.2->transformer-lens<3.0.0,>=2.0.0->sae-lens)
Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.2->transformer-lens<3.0.0,>=2.0.0->sae-lens) (0.6.2)
Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=2.2->transformer-lens<3.0.0,>=2.0.0->sae-lens) (2.21.5)
Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.2->transformer-lens<3.0.0,>=2.0.0->sae-lens) (12.4.127)
Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch>=2.2->transformer-lens<3.0.0,>=2.0.0->sae-lens)
Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.2->transformer-lens<3.0.0,>=2.0.0->sae-lens) (3.2.0)
Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.2->transformer-lens<3.0.0,>=2.0.0->sae-lens) (1.13.1)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=2.2->transformer-lens<3.0.0,>=2.0.0->sae-lens) (1.3.0)
Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.11/dist-packages (from wandb>=0.13.5->transformer-lens<3.0.0,>=2.0.0->sae-lens) (0.4.0)
Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.11/dist-packages (from wandb>=0.13.5->transformer-lens<3.0.0,>=2.0.0->sae-lens) (3.1.44)
Requirement already satisfied: platformdirs in /usr/local/lib/python3.11/dist-packages (from wandb>=0.13.5->transformer-lens<3.0.0,>=2.0.0->sae-lens) (4.3.7)
Requirement already satisfied: protobuf!=4.21.0,!=5.28.0,<6,>=3.19.0 in /usr/local/lib/python3.11/dist-packages (from wandb>=0.13.5->transformer-lens<3.0.0,>=2.0.0->sae-lens) (5.29.4)
Requirement already satisfied: pydantic<3 in /usr/local/lib/python3.11/dist-packages (from wandb>=0.13.5->transformer-lens<3.0.0,>=2.0.0->sae-lens) (2.11.3)
Requirement already satisfied: sentry-sdk>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from wandb>=0.13.5->transformer-lens<3.0.0,>=2.0.0->sae-lens) (2.26.1)
Requirement already satisfied: setproctitle in /usr/local/lib/python3.11/dist-packages (from wandb>=0.13.5->transformer-lens<3.0.0,>=2.0.0->sae-lens) (1.3.5)
Requirement already satisfied: setuptools in /usr/local/lib/python3.11/dist-packages (from wandb>=0.13.5->transformer-lens<3.0.0,>=2.0.0->sae-lens) (75.2.0)
Collecting dol (from graze->babe<0.0.8,>=0.0.7->sae-lens)
Downloading dol-0.3.16-py3-none-any.whl.metadata (18 kB)
INFO: pip is looking at multiple versions of multiprocess to determine which version is compatible with other requirements. This could take a while.
Collecting multiprocess (from datasets<3.0.0,>=2.17.1->sae-lens)
Downloading multiprocess-0.70.17-py311-none-any.whl.metadata (7.2 kB)
Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting config2py (from py2store->babe<0.0.8,>=0.0.7->sae-lens)
Downloading config2py-0.1.37-py3-none-any.whl.metadata (14 kB)
Requirement already satisfied: importlib_resources in /usr/local/lib/python3.11/dist-packages (from py2store->babe<0.0.8,>=0.0.7->sae-lens) (6.5.2)
Requirement already satisfied: iniconfig in /usr/local/lib/python3.11/dist-packages (from pytest->pytest-profiling<2.0.0,>=1.7.0->sae-lens) (2.1.0)
Requirement already satisfied: pluggy<2,>=1.5 in /usr/local/lib/python3.11/dist-packages (from pytest->pytest-profiling<2.0.0,>=1.7.0->sae-lens) (1.5.0)
Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from gitpython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens<3.0.0,>=2.0.0->sae-lens) (4.0.12)
Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.11/dist-packages (from markdown-it-py>=2.2.0->rich>=12.6.0->transformer-lens<3.0.0,>=2.0.0->sae-lens) (0.1.2)
Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from pydantic<3->wandb>=0.13.5->transformer-lens<3.0.0,>=2.0.0->sae-lens) (0.7.0)
Requirement already satisfied: pydantic-core==2.33.1 in /usr/local/lib/python3.11/dist-packages (from pydantic<3->wandb>=0.13.5->transformer-lens<3.0.0,>=2.0.0->sae-lens) (2.33.1)
Requirement already satisfied: typing-inspection>=0.4.0 in /usr/local/lib/python3.11/dist-packages (from pydantic<3->wandb>=0.13.5->transformer-lens<3.0.0,>=2.0.0->sae-lens) (0.4.0)
Collecting i2 (from config2py->py2store->babe<0.0.8,>=0.0.7->sae-lens)
Downloading i2-0.1.46-py3-none-any.whl.metadata (2.1 kB)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch>=2.2->transformer-lens<3.0.0,>=2.0.0->sae-lens) (3.0.2)
Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.11/dist-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens<3.0.0,>=2.0.0->sae-lens) (5.0.2)
Downloading sae_lens-5.9.1-py3-none-any.whl (131 kB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m131.3/131.3 kB[0m [31m14.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyzmq-26.0.0-cp311-cp311-manylinux_2_28_x86_64.whl (920 kB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m920.1/920.1 kB[0m [31m46.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading automated_interpretability-0.0.9-py3-none-any.whl (66 kB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m66.5/66.5 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading babe-0.0.7-py3-none-any.whl (6.9 kB)
Downloading datasets-2.21.0-py3-none-any.whl (527 kB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m527.3/527.3 kB[0m [31m43.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading plotly_express-0.4.1-py2.py3-none-any.whl (2.9 kB)
Downloading pytest_profiling-1.8.1-py3-none-any.whl (9.9 kB)
Downloading python_dotenv-1.1.0-py3-none-any.whl (20 kB)
Downloading safetensors-0.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (435 kB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m435.0/435.0 kB[0m [31m39.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading transformer_lens-2.15.0-py3-none-any.whl (189 kB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m189.2/189.2 kB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading typer-0.12.5-py3-none-any.whl (47 kB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m47.3/47.3 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading zstandard-0.22.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.4 MB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m5.4/5.4 MB[0m [31m44.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m739.7/739.7 kB[0m [31m54.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading better_abc-0.0.3-py3-none-any.whl (3.5 kB)
Downloading blobfile-2.1.1-py3-none-any.whl (73 kB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m73.7/73.7 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading boostedblob-0.15.6-py3-none-any.whl (59 kB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m59.2/59.2 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m116.3/116.3 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fancy_einsum-0.0.3-py3-none-any.whl (6.2 kB)
Downloading fsspec-2024.6.1-py3-none-any.whl (177 kB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m177.6/177.6 kB[0m [31m19.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading httpx-0.27.2-py3-none-any.whl (76 kB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m76.4/76.4 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jaxtyping-0.3.1-py3-none-any.whl (55 kB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m55.3/55.3 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m18.3/18.3 MB[0m [31m110.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m1.2/1.2 MB[0m [31m69.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m363.4/363.4 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m13.8/13.8 MB[0m [31m84.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m24.6/24.6 MB[0m [31m65.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)
[2K [90mââââââââââââââââââââââââââââââââââââââââ[0m [32m883.7/883.7 kB[0m [31m43.3 MB/s[0m eta [36m0:00:00[0m
Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)
Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)
Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)
Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)
Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)
Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)
Downloading gprof2dot-2025.4.14-py3-none-any.whl (37 kB)
Downloading graze-0.1.29-py3-none-any.whl (19 kB)
Downloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
Downloading lxml-4.9.4-cp311-cp311-manylinux_2_28_x86_64.whl (7.9 MB)
Downloading pycryptodomex-3.22.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.3 MB)
Downloading uvloop-0.21.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.0 MB)
Downloading wadler_lindig-0.1.5-py3-none-any.whl (20 kB)
Downloading config2py-0.1.37-py3-none-any.whl (32 kB)
Downloading dol-0.3.16-py3-none-any.whl (254 kB)
Downloading i2-0.1.46-py3-none-any.whl (202 kB)
Building wheels for collected packages: transformers-stream-generator, py2store
Building wheel for transformers-stream-generator (setup.py) ... done
Created wheel for transformers-stream-generator: filename=transformers_stream_generator-0.0.5-py3-none-any.whl size=12426 sha256=e92a3185e43fef8d02bc6e05ef3f5f82930caacae828f204407c128afbc850f1
Stored in directory: /root/.cache/pip/wheels/23/e8/f0/b3c58c12d1ffe60bcc8c7d121115f26b2c1878653edfca48db
Building wheel for py2store (setup.py) ... done
Created wheel for py2store: filename=py2store-0.1.20-py3-none-any.whl size=118411 sha256=83833241be63fc0cff3fcab685cead9e1fd53e1c14edae95e258347c43fde023
Stored in directory: /root/.cache/pip/wheels/b6/c0/a9/8ba28129562790a3ba62e4de4241dac5474df73d0e8a64e27a
Successfully built transformers-stream-generator py2store
Installing collected packages: i2, dol, better-abc, zstandard, xxhash, wadler-lindig, uvloop, safetensors, pyzmq, python-dotenv, pycryptodomex, nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, numpy, lxml, gprof2dot, fsspec, fancy-einsum, dill, config2py, beartype, tiktoken, pytest-profiling, py2store, nvidia-cusparse-cu12, nvidia-cudnn-cu12, multiprocess, jaxtyping, httpx, graze, blobfile, typer, nvidia-cusolver-cu12, boostedblob, babe, plotly-express, datasets, automated-interpretability, transformers-stream-generator, transformer-lens, sae-lens
Attempting uninstall: zstandard
Found existing installation: zstandard 0.23.0
Uninstalling zstandard-0.23.0:
Successfully uninstalled zstandard-0.23.0
Attempting uninstall: safetensors
Found existing installation: safetensors 0.5.3
Uninstalling safetensors-0.5.3:
Successfully uninstalled safetensors-0.5.3
Attempting uninstall: pyzmq
Found existing installation: pyzmq 24.0.1
Uninstalling pyzmq-24.0.1:
Successfully uninstalled pyzmq-24.0.1
Attempting uninstall: nvidia-nvjitlink-cu12
Found existing installation: nvidia-nvjitlink-cu12 12.5.82
Uninstalling nvidia-nvjitlink-cu12-12.5.82:
Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82
Attempting uninstall: nvidia-curand-cu12
Found existing installation: nvidia-curand-cu12 10.3.6.82
Uninstalling nvidia-curand-cu12-10.3.6.82:
Successfully uninstalled nvidia-curand-cu12-10.3.6.82
Attempting uninstall: nvidia-cufft-cu12
Found existing installation: nvidia-cufft-cu12 11.2.3.61
Uninstalling nvidia-cufft-cu12-11.2.3.61:
Successfully uninstalled nvidia-cufft-cu12-11.2.3.61
Attempting uninstall: nvidia-cuda-runtime-cu12
Found existing installation: nvidia-cuda-runtime-cu12 12.5.82
Uninstalling nvidia-cuda-runtime-cu12-12.5.82:
Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82
Attempting uninstall: nvidia-cuda-nvrtc-cu12
Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82
Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:
Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82
Attempting uninstall: nvidia-cuda-cupti-cu12
Found existing installation: nvidia-cuda-cupti-cu12 12.5.82
Uninstalling nvidia-cuda-cupti-cu12-12.5.82:
Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82
Attempting uninstall: nvidia-cublas-cu12
Found existing installation: nvidia-cublas-cu12 12.5.3.2
Uninstalling nvidia-cublas-cu12-12.5.3.2:
Successfully uninstalled nvidia-cublas-cu12-12.5.3.2
Attempting uninstall: numpy
Found existing installation: numpy 2.0.2
Uninstalling numpy-2.0.2:
Successfully uninstalled numpy-2.0.2
Attempting uninstall: lxml
Found existing installation: lxml 5.3.2
Uninstalling lxml-5.3.2:
Successfully uninstalled lxml-5.3.2
Attempting uninstall: fsspec
Found existing installation: fsspec 2025.3.2
Uninstalling fsspec-2025.3.2:
Successfully uninstalled fsspec-2025.3.2
Attempting uninstall: nvidia-cusparse-cu12
Found existing installation: nvidia-cusparse-cu12 12.5.1.3
Uninstalling nvidia-cusparse-cu12-12.5.1.3:
Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3
Attempting uninstall: nvidia-cudnn-cu12
Found existing installation: nvidia-cudnn-cu12 9.3.0.75
Uninstalling nvidia-cudnn-cu12-9.3.0.75:
Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75
Attempting uninstall: httpx
Found existing installation: httpx 0.28.1
Uninstalling httpx-0.28.1:
Successfully uninstalled httpx-0.28.1
Attempting uninstall: typer
Found existing installation: typer 0.15.2
Uninstalling typer-0.15.2:
Successfully uninstalled typer-0.15.2
Attempting uninstall: nvidia-cusolver-cu12
Found existing installation: nvidia-cusolver-cu12 11.6.3.83
Uninstalling nvidia-cusolver-cu12-11.6.3.83:
Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.2 requires fsspec==2025.3.2, but you have fsspec 2024.6.1 which is incompatible.
langsmith 0.3.31 requires zstandard<0.24.0,>=0.23.0, but you have zstandard 0.22.0 which is incompatible.
google-genai 1.10.0 requires httpx<1.0.0,>=0.28.1, but you have httpx 0.27.2 which is incompatible.
thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.26.4 which is incompatible.
Successfully installed automated-interpretability-0.0.9 babe-0.0.7 beartype-0.14.1 better-abc-0.0.3 blobfile-2.1.1 boostedblob-0.15.6 config2py-0.1.37 datasets-2.21.0 dill-0.3.8 dol-0.3.16 fancy-einsum-0.0.3 fsspec-2024.6.1 gprof2dot-2025.4.14 graze-0.1.29 httpx-0.27.2 i2-0.1.46 jaxtyping-0.3.1 lxml-4.9.4 multiprocess-0.70.16 numpy-1.26.4 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 plotly-express-0.4.1 py2store-0.1.20 pycryptodomex-3.22.0 pytest-profiling-1.8.1 python-dotenv-1.1.0 pyzmq-26.0.0 sae-lens-5.9.1 safetensors-0.4.5 tiktoken-0.9.0 transformer-lens-2.15.0 transformers-stream-generator-0.0.5 typer-0.12.5 uvloop-0.21.0 wadler-lindig-0.1.5 xxhash-3.5.0 zstandard-0.22.0
from sae_lens import SAE
sae, cfg_dict, sparsity = SAE.from_pretrained(
release = "gemma-scope-2b-pt-res-canonical",
sae_id = "layer_20/width_16k/canonical",
)
print(sae, cfg_dict, sparsity)
SAE(
(activation_fn): ReLU()
(hook_sae_input): HookPoint()
(hook_sae_acts_pre): HookPoint()
(hook_sae_acts_post): HookPoint()
(hook_sae_output): HookPoint()
(hook_sae_recons): HookPoint()
(hook_sae_error): HookPoint()
) {'architecture': 'jumprelu', 'd_in': 2304, 'd_sae': 16384, 'dtype': 'float32', 'model_name': 'gemma-2-2b', 'hook_name': 'blocks.20.hook_resid_post', 'hook_layer': 20, 'hook_head_index': None, 'activation_fn_str': 'relu', 'finetuning_scaling_factor': False, 'sae_lens_training_version': None, 'prepend_bos': True, 'dataset_path': 'monology/pile-uncopyrighted', 'context_size': 1024, 'dataset_trust_remote_code': True, 'apply_b_dec_to_input': False, 'normalize_activations': None, 'device': 'cpu', 'neuronpedia_id': 'gemma-2-2b/20-gemmascope-res-16k'} None
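The config above reports `architecture: 'jumprelu'`, `d_in: 2304`, `d_sae: 16384` and `apply_b_dec_to_input: False`. As a rough sketch of what `sae.encode` computes for this architecture (a simplified NumPy re-implementation with toy sizes and random weights, not the sae_lens source), each residual vector is projected into feature space and a feature is kept only where its pre-activation exceeds a per-feature threshold. The names `W_enc`, `b_enc` and `threshold` are stand-ins chosen for illustration.

```python
import numpy as np

def jumprelu_encode(x, W_enc, b_enc, threshold):
    """Toy JumpReLU encoder: x has shape (seq_len, d_in); returns (seq_len, d_sae).

    Assumes no decoder bias is subtracted from the input first
    (mirroring apply_b_dec_to_input=False in the config above).
    """
    hidden_pre = x @ W_enc + b_enc                 # pre-activations, (seq_len, d_sae)
    gate = hidden_pre > threshold                  # per-feature "jump" gate
    return np.maximum(hidden_pre, 0.0) * gate      # ReLU, then zero out sub-threshold values

rng = np.random.default_rng(0)
d_in, d_sae, seq_len = 8, 32, 5                    # toy sizes (the real SAE maps 2304 -> 16384)
W_enc = rng.normal(size=(d_in, d_sae))
b_enc = np.zeros(d_sae)
threshold = np.full(d_sae, 1.0)                    # hypothetical thresholds, all set to 1.0
x = rng.normal(size=(seq_len, d_in))

acts = jumprelu_encode(x, W_enc, b_enc, threshold)
print(acts.shape)                                  # (5, 32)
print((acts > 0).mean())                           # fraction of active (token, feature) pairs
```

The gate is what makes the activations sparse: any feature whose pre-activation stays below its threshold contributes exactly zero, which is why most of the 16384 features are inactive on a given token.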
from IPython.display import IFrame
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
# Neuronpedia embed URL layout: /<model id>/<SAE id>/<feature index>
def get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=0):
    return html_template.format(sae_release, sae_id, feature_idx)
########################## TODO (Q7.1) ############################
html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=10004)
IFrame(html, width=1200, height=600)
###################################################################
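To compare several features, the same template can be filled in a loop. A small sketch, assuming the `/<model id>/<SAE id>/<feature index>` layout used by `html_template` above; `dashboard_urls` is a hypothetical helper and the feature indices other than 10004 are arbitrary examples.

```python
html_template = (
    "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true"
    "&embedplots=true&embedtest=true&height=300"
)

def dashboard_urls(feature_indices, model_id="gemma-2-2b", sae_id="20-gemmascope-res-16k"):
    """Return {feature_idx: embed URL} for each requested feature (hypothetical helper)."""
    return {i: html_template.format(model_id, sae_id, i) for i in feature_indices}

urls = dashboard_urls([10004, 0, 42])
print(urls[10004])
```

In the notebook, each URL can be passed to `IFrame(url, width=1200, height=600)` exactly as in the Q7.1 cell.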
Q7.2~7.3: Maximum activation comparison
######################## (Q7.2 ~ 7.3) ########################
import torch
import matplotlib.pyplot as plt

def get_max_activation(model, tokenizer, sae, prompt, feature_idx=10004):
    """
    Computes the maximum activation of a specific feature in a Sparse Autoencoder (SAE)
    for a given prompt.
    Args:
        model: The Transformer model used for generating hidden states.
        tokenizer: The tokenizer for encoding the prompt.
        sae: The Sparse Autoencoder for encoding hidden states.
        prompt (str): The input text prompt.
        feature_idx (int, optional): The index of the feature in SAE. Defaults to 10004.
    Returns:
        float: The maximum activation value for the specified feature index.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sae.to(device)
    # Tokenize the input prompt and get model outputs (no gradients are needed for analysis)
    tokens = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(tokens, output_hidden_states=True)
        # Extract hidden states from the SAE's layer.
        # Under the Hugging Face convention, hidden_states[0] is the embedding output, so
        # hidden_states[k] is the residual stream entering block k. The SAE was trained on
        # blocks.20.hook_resid_post, which would be hidden_states[hook_layer + 1]; the index
        # hook_layer is kept here so the results recorded below are reproduced.
        hidden_states = outputs.hidden_states[sae.cfg.hook_layer]
        # Encode hidden states using the SAE; squeeze() drops the batch dimension of size 1
        feature_acts = sae.encode(hidden_states.to(device)).squeeze()  # Shape: (seq_len, d_sae)
    feature_acts = feature_acts.reshape(-1, feature_acts.shape[-1])
    # Activations of the chosen feature at every token position
    feature_col = feature_acts[:, feature_idx].float().cpu().numpy()
    max_activation = float(feature_col.max())
    # Plot the distribution of this feature's activations over token positions
    plt.figure(figsize=(8, 5))
    plt.hist(feature_col, bins=50, alpha=0.75, color='blue', edgecolor='black')
    plt.xlabel(f"Activation values (Feature {feature_idx})")
    plt.ylabel("Frequency")
    plt.title(f"Activation Distribution for Feature {feature_idx} - Prompt: '{prompt}'")
    plt.grid(True)
    plt.show()
    return max_activation
feature_idx = 10004
# Define the prompts
prompt_a = "Time travel offers me the opportunity to correct past errors, but it comes with its own set of risks."
prompt_b = "I accept that my decisions shape my future, and though mistakes are inevitable, they define who I become."
# Calculate the maximum activations for each prompt using the feature index
max_activation_a = get_max_activation(model, tokenizer, sae, prompt_a, feature_idx=feature_idx)
max_activation_b = get_max_activation(model, tokenizer, sae, prompt_b, feature_idx=feature_idx)
# Print the comparison
print(f"max_activation for prompt_a: {max_activation_a}")
print(f"max_activation for prompt_b: {max_activation_b}")
###########################################################
max_activation for prompt_a: 58.03689193725586
max_activation for prompt_b: 31.76049041748047
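`get_max_activation` indexes `outputs.hidden_states` with `sae.cfg.hook_layer`. Under the Hugging Face convention, that tuple starts with the embedding output and then holds one entry per block, so the TransformerLens-style hook names in the SAE config map onto it with an offset. A minimal sketch of that mapping; `hf_hidden_state_index` is a hypothetical helper, and only the two residual-stream hooks are handled.

```python
import re

def hf_hidden_state_index(hook_name: str) -> int:
    """Map a residual-stream hook name to an index into HF `outputs.hidden_states`.

    Assumes hidden_states[0] is the embedding output and hidden_states[k + 1]
    is the output of block k (the Hugging Face convention).
    """
    match = re.fullmatch(r"blocks\.(\d+)\.hook_resid_(pre|post)", hook_name)
    if match is None:
        raise ValueError(f"unsupported hook name: {hook_name}")
    block, kind = int(match.group(1)), match.group(2)
    # resid_pre of block k is what enters block k; resid_post is what leaves it
    return block if kind == "pre" else block + 1

print(hf_hidden_state_index("blocks.20.hook_resid_post"))  # 21
print(hf_hidden_state_index("blocks.20.hook_resid_pre"))   # 20
```

Under this mapping, `hidden_states[21]` corresponds to the SAE's training hook `blocks.20.hook_resid_post`, while `hidden_states[20]` is the residual stream one block earlier.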
Q7.4~7.6: Activation distribution for a specific layer
import numpy as np
import torch
import matplotlib.pyplot as plt

def plot_token_activations(model, tokenizer, sae, prompt, feature_idx=10004, layer_idx=None):
    """
    Plots activations for each token in a specific layer.
    Args:
        model: The transformer model.
        tokenizer: Tokenizer for encoding input text.
        sae: Sparse Autoencoder model.
        prompt: Input text string.
        feature_idx: Index of the feature to analyze.
        layer_idx: Index into outputs.hidden_states to analyze (None uses sae.cfg.hook_layer).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sae.to(device)
    # Tokenize input and get model output
    tokens = tokenizer(prompt, return_tensors="pt")
    token_ids = tokens["input_ids"].to(device)
    token_list = tokenizer.convert_ids_to_tokens(token_ids.squeeze().tolist())
    with torch.no_grad():
        outputs = model(token_ids, output_hidden_states=True)
        # Choose layer
        layer_idx = layer_idx if layer_idx is not None else sae.cfg.hook_layer
        hidden_states = outputs.hidden_states[layer_idx]
        # Pass through SAE; squeeze() drops the batch dimension of size 1
        feature_acts = sae.encode(hidden_states).squeeze()  # (seq_len, d_sae)
    print(f"feature_acts shape: {feature_acts.shape}")
    # Extract activations for the chosen feature (one value per token)
    activations = feature_acts[:, feature_idx].cpu().numpy()
    # Plot
    plt.figure(figsize=(10, 5))
    plt.bar(range(len(token_list)), activations, color='blue', alpha=0.7)
    plt.xticks(range(len(token_list)), token_list, rotation=45)
    plt.xlabel("Tokens")
    plt.ylabel(f"Activation Value (Feature {feature_idx})")
    plt.title(f"Token-wise Activations for Layer {layer_idx}")
    plt.grid(True)
    plt.show()
######################## (Q7.4 ~ 7.6) ########################
# Simply observe the figure
layer_idx = 24
prompt = "Time travel will become a reality as technology continues to advance."
plot_token_activations(model, tokenizer, sae, prompt, feature_idx, layer_idx)
###################################################################
feature_acts shape: torch.Size([13, 16384])
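Reading the bar chart by eye works for one prompt; a small helper that ranks tokens by activation makes the comparison easier to repeat across layers and prompts. A sketch assuming `token_list` and a 1-D `activations` array shaped like the ones built inside `plot_token_activations`; the function name and the toy values are illustrative only, not notebook outputs.

```python
import numpy as np

def top_activating_tokens(token_list, activations, k=3):
    """Return the k (token, position, activation) triples with the largest activations."""
    activations = np.asarray(activations)
    order = np.argsort(activations)[::-1][:k]      # positions sorted by descending activation
    return [(token_list[i], int(i), float(activations[i])) for i in order]

# Toy values, not outputs from the notebook
tokens = ["<bos>", "Time", " travel", " will", " become"]
acts = [0.0, 12.5, 40.0, 0.0, 3.25]
print(top_activating_tokens(tokens, acts, k=2))
# [(' travel', 2, 40.0), ('Time', 1, 12.5)]
```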
Q7.7~7.9: Activation across layers for a specific token
import torch
import matplotlib.pyplot as plt

def plot_layer_activations(model, tokenizer, sae, prompt, token_idx=0, feature_idx=10004):
    """
    Plots activations of a specific token across all layers.
    Args:
        model: The transformer model.
        tokenizer: Tokenizer for encoding input text.
        sae: Sparse Autoencoder model.
        prompt: Input text string.
        token_idx: Index of the token to analyze.
        feature_idx: Index of the feature to analyze.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sae.to(device)
    # Tokenize input and get model output
    tokens = tokenizer(prompt, return_tensors="pt")
    token_ids = tokens["input_ids"].to(device)
    token_list = tokenizer.convert_ids_to_tokens(token_ids.squeeze().tolist())
    with torch.no_grad():
        outputs = model(token_ids, output_hidden_states=True)
        # Collect activations across all entries of hidden_states
        # (the embedding output plus one entry per block)
        num_layers = len(outputs.hidden_states)
        activations = []
        for layer_idx in range(num_layers):
            hidden_states = outputs.hidden_states[layer_idx]
            sae_in = hidden_states
            feature_acts = sae.encode(sae_in).squeeze()  # (seq_len, d_sae)
            # print(f"feature_acts shape: {feature_acts.shape}")
            activations.append(feature_acts[token_idx, feature_idx].item())
    # Plot
    plt.figure(figsize=(8, 5))
    plt.plot(range(num_layers), activations, marker="o", linestyle="-", color="blue")
    plt.xlabel("Layer")
    plt.ylabel(f"Activation Value (Feature {feature_idx})")
    plt.title(f"Activation Across Layers for Token '{token_list[token_idx]}'")
    plt.xticks(range(num_layers))
    plt.grid(True)
    plt.show()
######################## (Q7.7 ~ 7.9) ########################
# Alter the token index to observe the figure
token_idx = 1
prompt = "Time travel will become a reality as technology continues to advance."
plot_layer_activations(model, tokenizer, sae, prompt, token_idx)
###################################################################
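`plot_layer_activations` encodes the whole sequence at every layer and then keeps a single entry. Because the SAE encoder acts on each position's vector independently, encoding only the chosen token's vector gives the same value with less work. The sketch below shows that idea with NumPy stand-ins: `hidden_states` is a list of `(seq_len, d_model)` arrays (one per entry of `outputs.hidden_states`, batch dimension removed) and `encode` is any function mapping `(n, d_model)` to `(n, d_sae)`. Both are assumptions, with a random linear+ReLU encoder standing in for the real SAE.

```python
import numpy as np

def feature_across_layers(hidden_states, encode, token_idx, feature_idx):
    """Activation of one feature for one token at every entry of hidden_states."""
    values = []
    for layer_hidden in hidden_states:
        token_vec = layer_hidden[token_idx : token_idx + 1]    # keep a (1, d_model) batch
        values.append(float(encode(token_vec)[0, feature_idx]))
    return values

rng = np.random.default_rng(0)
seq_len, d_model, d_sae, n_entries = 6, 4, 10, 3           # toy sizes
W = rng.normal(size=(d_model, d_sae))

def encode(h):
    """Stand-in encoder (linear + ReLU), not the real SAE."""
    return np.maximum(h @ W, 0.0)

hidden_states = [rng.normal(size=(seq_len, d_model)) for _ in range(n_entries)]

per_layer = feature_across_layers(hidden_states, encode, token_idx=1, feature_idx=7)
peak_layer = int(np.argmax(per_layer))
print(per_layer, peak_layer)
```

With the real model, the same loop would take `[h[0] for h in outputs.hidden_states]` and `sae.encode`, keeping in mind that entry 0 is the embedding output rather than a block output.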