Hi everyone,
I recently saw a post about Comgra in this subreddit (check it out, it’s very cool!), which inspired me to share a similar pet project of mine, though with a slightly different goal.
I’ve been very interested in understanding how LLMs work behind the scenes, and I’ve started reading a bunch of cool papers like:
- What Does BERT Look At? An Analysis of BERT’s Attention: It shows how individual attention heads learn very specific functions (like finding the direct object, resolving coreference, etc.). There’s a small sketch of pulling out these attention maps right after this list.
- Transformer Feed-Forward Layers Are Key-Value Memories: Shows how memories are encoded in transformer feed-forward layers and how to extract them. Very cool!!
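As a taste of the first paper’s kind of analysis, here’s a minimal sketch (plain Hugging Face transformers, nothing from my project yet) of grabbing the raw attention maps it studies:

```python
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", output_attentions=True)

inputs = tokenizer("The dog chased the ball.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# outputs.attentions is a tuple with one tensor per layer,
# each shaped (batch, num_heads, seq_len, seq_len)
for layer_idx, layer_attn in enumerate(outputs.attentions):
    print(f"layer {layer_idx}: {tuple(layer_attn.shape)}")
```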
I find it amazing that we have such powerful models today, yet we are still learning how to lift the covers off the magical black boxes they are.
To this end, I’ve been working on a library/framework that aims to let you perform these analyses (and more) in a general fashion on any model (striving to support modern LLMs like Llama 2 too).
Here’s an example result showing the attention head clustering from the BERT paper, but this time for Llama 2 (in the future I hope to add the ability to name each attention head according to its function, perhaps using LLMs for that too):
[Image: Attention head clustering in Llama 2. Each point is an attention head; each color represents a layer.]
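If you’re curious how a plot like this comes together, here’s a hedged sketch of the general recipe (not the library’s actual code): treat each head’s attention map as a feature vector, embed the pairwise distances in 2D, and color points by layer.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import MDS

def plot_head_clusters(attn_maps):
    """attn_maps: assumed list of tensors, one per layer, each shaped
    (num_heads, seq_len, seq_len), e.g. outputs.attentions from the
    snippet above with the batch dimension squeezed out."""
    feats, layers = [], []
    for layer_idx, layer_attn in enumerate(attn_maps):
        for head_attn in layer_attn:
            feats.append(head_attn.flatten().detach().cpu().numpy())
            layers.append(layer_idx)

    # 2D embedding of heads from their pairwise distances (plain euclidean
    # here; the BERT paper uses Jensen-Shannon divergence between
    # attention distributions instead)
    xy = MDS(n_components=2).fit_transform(np.stack(feats))
    plt.scatter(xy[:, 0], xy[:, 1], c=layers, cmap="viridis")
    plt.colorbar(label="layer")
    plt.show()
```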
The library implements dynamic instrumentation of PyTorch functions/tensors, which allows concise code like:
```python
import fmrai
from fmrai import instrument_model  # hypothetical import path; see the repo for the real one

# model/tokenizer: any Hugging Face transformers model + tokenizer loaded beforehand
with fmrai.fmrai() as fmr:
    m = instrument_model(model)

    # record every tensor operation during this forward pass
    with fmr.track() as tracker:
        m(**tokenizer("Hello World", return_tensors="pt"))

    g = tracker.build_graph()
    g.save_dot('graph.dot')
```
which gives you a DOT file of the model’s full computation graph.
You can then use the tools in the library to find where the attention is computed, extract the tensors, run analyses, etc.
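For context on what the instrumentation saves you from: without it, you’d typically register forward hooks by hand on each module you care about. A rough sketch of that manual approach (plain PyTorch, not this library’s API; the module-name filter is hypothetical and model-specific):

```python
# Manual alternative to instrumentation: capture module outputs with forward hooks.
captured = {}

def save_output(name):
    def hook(module, inputs, output):
        captured[name] = output
    return hook

for name, module in model.named_modules():
    # match the attention submodules of your particular model;
    # "self_attn" is just an example, names vary between architectures
    if "self_attn" in name:
        module.register_forward_hook(save_output(name))

model(**tokenizer("Hello World", return_tensors="pt"))
print(list(captured.keys()))
```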
In the future, I hope to add support for image models and more.
Please note that this project is very early-stage and not very stable at the moment. I hope someone finds it useful/interesting!