Another post about one of the pieces of DALL-E, the vision transformer (Dosovitskiy et al. 2021), which is specifically used as one of the vision encoders in CLIP (Radford et al. 2021), which is itself used to condition the decoding process in DALL-E.
The vision transformer is at its base really not all that complicated: An image is divided into patches, and each patch is flattened, sent through an embedding layer, and arranged in a sequence to provide the expected input to a stack of multi-headed-attention + position-wise feed-forward layers identical to the architectures used to process text. Somewhat miraculously, with enough data, the vision transformer learns the spatial nature of the input even without the inductive biases of convolutional networks.
Here’s a link to the colab notebook for training this thing up.
I did roughly the following:
- Essentially copy the transformer encoder architecture from Harvard NLP
- Prepend the patch creation and embedding plus learnable positional embeddings
- Add the classification token and classification head as described in (Dosovitskiy et al. 2021)
Components added to base transformer architecture
I implemented an image patcher like so:
import torch
import einops

class imagePatches(torch.nn.Module):
    def __init__(self, patch_size=(8,8), input_channels = 3, stride = 8, embedding_dim=768):
        super().__init__()
        self.patch_size = patch_size
        # Unfold slides a window over the image and flattens each window into a column
        self.unfold = torch.nn.Unfold(patch_size, stride = stride)
        self.patch_embedding = torch.nn.Linear(
            patch_size[0]*patch_size[1]*input_channels,
            embedding_dim
        )

    def forward(self, img):
        patches = self.unfold(img)
        # (batch, flattened patch, n_patches) -> (batch, n_patches, flattened patch)
        patches = einops.rearrange(patches, "b c p -> b p c")
        embeddings = self.patch_embedding(patches)
        return embeddings, patches
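As a rough sanity check on the shapes: with no padding and no dilation (the torch.nn.Unfold defaults), the number of windows along each axis is floor((size - patch) / stride) + 1. The helper below is hypothetical (it is not in the notebook) and only does that arithmetic, returning the number of patches and the length of each flattened patch.

```python
def patch_grid(height, width, patch_size=(8, 8), stride=8, input_channels=3):
    """Return (n_patches, patch_dim) for an Unfold-style patcher.

    Assumes no padding and no dilation, matching the defaults of torch.nn.Unfold.
    """
    ph, pw = patch_size
    rows = (height - ph) // stride + 1    # windows that fit vertically
    cols = (width - pw) // stride + 1     # windows that fit horizontally
    n_patches = rows * cols
    patch_dim = ph * pw * input_channels  # input size of the linear embedding
    return n_patches, patch_dim


# A 32x32 RGB image tiles cleanly into a 4x4 grid of 8x8 patches.
print(patch_grid(32, 32))                    # (16, 192)
# A 28x28 single-channel image does not: only 3 windows fit per side,
# so the last 4 rows and columns of pixels are dropped.
print(patch_grid(28, 28, input_channels=1))  # (9, 64)
```

The first number is the value that n_patches has to match when sizing the learnable positional embeddings.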
The classification head looked like:
# ... blah blah module super init
def forward(self, x: torch.Tensor, **kwargs):
    x = encoder(x, **kwargs)
    # x = Reduce("b p c -> b c", reduction='mean')(x)
    # x = self.layernorm(x)
    x = self.layernorm(x[:,0])  # keep only the classification token
    x = self.hidden(x)
    x = torch.tanh(x)
    x = self.classification(x)
    return x
I switched back and forth between the mean pooling (commented out) and using only the classification token output before going through the classification layer.
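The two options differ only in how the (batch, tokens, channels) encoder output is collapsed to one vector per image. Here is a minimal sketch, with a random tensor standing in for the encoder output and sizes made up for illustration:

```python
import torch

torch.manual_seed(0)
# Stand-in for the encoder output: 4 images, 1 class token + 16 patches, 32 channels.
encoded = torch.randn(4, 17, 32)

# Option 1: keep only the classification token, which sits at sequence index 0.
cls_pooled = encoded[:, 0]

# Option 2: average over every token in the sequence (class token included),
# which is what the commented-out Reduce("b p c -> b c", reduction='mean') does.
mean_pooled = encoded.mean(dim=1)

print(cls_pooled.shape, mean_pooled.shape)  # both torch.Size([4, 32])
```

With mean pooling, every token's output feeds the loss directly. With the class token, the patch outputs only influence the loss through attention into position 0.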
Just to show it, here’s the class embedding + learnable positional embedding inside the encoder class:
import copy

class vitEncoder(torch.nn.Module):
    def __init__(self, n_layers, embedding_dim, image_tokenizer, mha_layer, ff_layer, n_patches):
        super().__init__()
        self.image_tokenizer = image_tokenizer
        # One learnable position vector per patch, plus one for the class token
        self.positional_embedding = torch.nn.Parameter(
            torch.randn((n_patches + 1, embedding_dim))
        )
        self.embedding_dim = embedding_dim
        self.n_layers = n_layers
        self.encoder_layers = torch.nn.ModuleList([
            vitEncoderLayer(copy.deepcopy(mha_layer), copy.deepcopy(ff_layer))
            for _ in range(n_layers)
        ])
        self.class_embedding = torch.nn.Parameter(torch.randn((1, 1, embedding_dim)))

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None):
        x, _ = self.image_tokenizer(x)
        x = torch.concatenate([
            einops.repeat(self.class_embedding, "b t c -> (r b) t c", r = x.shape[0]), x],
            axis = 1  # patch sequence axis
        )
        x = x + self.positional_embedding
        for l in self.encoder_layers:
            x = l(x, attn_mask)
        return x
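Shape-wise, that forward pass turns (batch, n_patches, embedding_dim) into (batch, n_patches + 1, embedding_dim), then relies on broadcasting to add the same positional table to every image. A small sketch with made-up sizes, using Tensor.expand in place of einops.repeat (a swap made here only to keep the example free of extra dependencies):

```python
import torch

torch.manual_seed(0)
batch, n_patches, embedding_dim = 2, 9, 16  # illustrative sizes only

patch_tokens = torch.randn(batch, n_patches, embedding_dim)
class_embedding = torch.nn.Parameter(torch.randn(1, 1, embedding_dim))
positional_embedding = torch.nn.Parameter(torch.randn(n_patches + 1, embedding_dim))

# Copy the single class token once per image and put it at the front of the sequence.
class_tokens = class_embedding.expand(batch, -1, -1)
tokens = torch.cat([class_tokens, patch_tokens], dim=1)

# (n_patches + 1, dim) broadcasts against (batch, n_patches + 1, dim).
tokens = tokens + positional_embedding

print(tokens.shape)  # torch.Size([2, 10, 16])
```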
Training
First attempt:
Epoch 0, Loss is 2.305
Epoch 0, Loss is 2.367
Epoch 0, Loss is 2.301
Epoch 0, Loss is 2.298
...
Epoch 10, Loss is 2.300
Epoch 10, Loss is 2.303
Epoch 10, Loss is 2.304
Epoch 10, Loss is 2.312
… shit. Oh right the positional embeddings.
Epoch 0, Loss is 2.301
Epoch 0, Loss is 2.309
Epoch 0, Loss is 2.297
Epoch 0, Loss is 2.308
...
Epoch 10, Loss is 2.299
Epoch 10, Loss is 2.304
Epoch 10, Loss is 2.311
Epoch 10, Loss is 2.303
…damnit. Oh whoops I missed a couple activations.
Epoch 0, Loss is 2.331
Epoch 0, Loss is 2.299
Epoch 0, Loss is 2.297
Epoch 0, Loss is 2.312
...
Epoch 10, Loss is 2.291
Epoch 10, Loss is 2.303
Epoch 10, Loss is 2.312
Epoch 10, Loss is 2.301
…shit…
…
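One thing worth noting about those logs: a loss pinned at roughly 2.30 is what cross-entropy gives when a model spreads its probability evenly over 10 classes. That assumes a 10-class dataset such as MNIST, the one used further down. A quick check of the arithmetic:

```python
import math

n_classes = 10  # e.g. MNIST digits 0-9

# Cross-entropy of a uniform prediction: -log(1 / n_classes) = log(n_classes).
chance_loss = -math.log(1.0 / n_classes)

print(round(chance_loss, 3))  # 2.303
```

So a loss hovering around 2.30 suggests the model is producing near-uniform predictions, not learning slowly.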
*Reminds self this is an EXERCISE and pain is expected.* I did eventually get this working, but first a couple bugs I found along the way:
- Two very dumb mistakes implementing the layer normalization (after fixing it I switched to just using torch.nn.LayerNorm)
class LayerNorm(nn.Module):
    def __init__(self, features):
        super().__init__()
        self.w = nn.Parameter(torch.ones(features))
        self.b = nn.Parameter(torch.zeros(features))

    def forward(self, x, eps=1e-6):
        return self.w * x.mean(-1, keepdim=True) / (x.std(-1, keepdim=True) + eps) + self.b
You see it? Yeah, I’m not actually mean-centering in the numerator there: it is just the mean, not x minus the mean…
...
    def forward(self, x, eps=1e-6):
        return self.w * (x - x.mean(-1, keepdim=True)/(x.std(-1, keepdim=True) + eps)) + self.b
And here I’ve just fudged the parentheses…squint a bit and you’ll see.
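For reference, a corrected version of the hand-rolled layer could look something like the sketch below. It keeps the same x.std-based formula as the snippets above, so it is not numerically identical to torch.nn.LayerNorm, which uses the biased variance and puts eps inside the square root. The class name FixedLayerNorm is hypothetical.

```python
import torch
from torch import nn


class FixedLayerNorm(nn.Module):
    """Hand-rolled layer norm with the mean-centering and parentheses fixed."""

    def __init__(self, features):
        super().__init__()
        self.w = nn.Parameter(torch.ones(features))
        self.b = nn.Parameter(torch.zeros(features))

    def forward(self, x, eps=1e-6):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        # Centre first, then divide the centred values by the spread.
        return self.w * (x - mean) / (std + eps) + self.b


torch.manual_seed(0)
x = torch.randn(2, 5, 16) * 3.0 + 7.0  # deliberately off-centre, wide input
y = FixedLayerNorm(16)(x)

# Each feature vector should now have mean close to 0 and (unbiased) std close to 1.
print(y.mean(-1).abs().max().item())
print((y.std(-1) - 1.0).abs().max().item())
```

Both printed values should be tiny. Running the same check on either buggy version makes the problem obvious straight away.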
- MultiheadAttention expects the sequence dimension first? Apparently, if you don’t specify batch_first = True:
    attention_layer = torch.nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=n_head, batch_first = True)
  then MultiheadAttention treats the first dimension as the sequence dimension of the QKV input… This fails silently, since the self-attention matrix multiplication is still valid.
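One way to catch this kind of silent mix-up is to look at the shape of the returned attention weights, not the output, since the output shape matches the input either way. A small sketch with made-up sizes:

```python
import torch

torch.manual_seed(0)
batch, seq_len, embedding_dim, n_head = 2, 5, 16, 4
x = torch.randn(batch, seq_len, embedding_dim)  # laid out as (batch, sequence, channels)

seq_first = torch.nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=n_head)
batch_first = torch.nn.MultiheadAttention(
    embed_dim=embedding_dim, num_heads=n_head, batch_first=True
)

out_a, weights_a = seq_first(x, x, x)
out_b, weights_b = batch_first(x, x, x)

# The outputs have the same shape, so nothing looks wrong at first glance.
print(out_a.shape, out_b.shape)  # torch.Size([2, 5, 16]) for both

# The attention weights give it away: (batch, target_len, source_len).
print(weights_a.shape)  # torch.Size([5, 2, 2]): 5 "images", each attending over 2 "tokens"
print(weights_b.shape)  # torch.Size([2, 5, 5]): 2 images, each attending over 5 tokens
```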
Long story short, I did get the model to train up modestly on MNIST (I’ll try on the other two datasets later). You can see my successful training run against all others here:
In the end, it was switching from SGD to Adam that finally got the model to train properly. SGD is known to be very picky about the learning rate. I might try some sweeps over various LRs and report back.
Another (probably related) thing I ran into is that averaging all the outputs of the last layer resulted in some learning with SGD, but taking only the classification token output resulted in zero learning. Probably the learning rate(s) I was using were closer to being appropriate for averaging than for taking just the classification token output.