Well, 2018-2019 has been a really active couple of years for NLP, hasn’t it? BERT, Transformer-XL, XLNet and what not. The common working principle in all of these models is the transformer architecture proposed by Vaswani et al. But all these published models are very large in size and hard to tinker with. Even fine-tuning the XLNet architecture needs 4 GPUs. These heavy models make it difficult for a researcher to quickly build a small-scale prototype and observe how things pan out. For the same reason, I needed a model that can be trained on a single- or dual-GPU machine and whose architecture I can experiment with.

Right now I have only implemented a basic version of Transformer-XL without the memory functionality. My plan is to first extend it with XLNet’s permutation language model. The complete code is available in a GitHub repo: Transformer-XS for Xtra-Small ;).

This implementation is a modification of the code in the amazing tutorial in this Medium post by TDS. If you are new to transformers, I suggest going through that tutorial first. I will not explain the architecture in much detail in this post.

Open In Colab

Link to trained model: drive link

Setup GPU

We only need one GPU for training and inference.



import os, re, math, copy, time, sys
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm, tqdm_notebook
import pickle
import spacy
from collections import defaultdict
import subprocess
if not os.path.isfile('./utils.py'):
    print("Downloading utils.py...")
    url = "https://raw.githubusercontent.com/bsantraigi/Transformer-XS/master/utils.py"
    subprocess.run(["wget", url])
else:
    print("Found utils.py...")
Found utils.py...
from utils import Lang
print(f"# Using pytorch v{torch.__version__}")
# Using pytorch v1.1.0
MAX_LEN = 200

Check CUDA

Check whether CUDA is available. I haven’t dared to train this on CPU; even on a GPU it takes quite a lot of time to train well. If CUDA isn’t detected on your system, you might not have a GPU, or you might have the CPU variant of PyTorch installed. You can always run this on Google Colab.

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"# Using device: {device}")
# Using device: cuda

Download Wikitext-103 dataset

!mkdir data
mkdir: cannot create directory ‘data’: File exists
if not os.path.isdir("data/wikitext-103/"):
    print("Downloading data...")
    subprocess.run("wget -c https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip -P data".split())
    print("Unzipping data...")
    subprocess.run(["unzip", "data/wikitext-103-v1.zip", "-d", "data/"])
else:
    print("Found data...")
Found data...

Data Preprocessing

The main goal here is to create the vocab.

en = spacy.load('en_core_web_sm')
data_path = 'data/wikitext-103/'
train_lines = 1801350
test_lines = 4358
valid_lines = 3760
vocab = defaultdict(int)
split = 'train'
L = eval(f'{split}_lines')

Vocab Creation

# with open(data_path + f'wiki.{split}.tokens') as f:
#     _progress = 0
#     buffer = []
#     for line in tqdm_notebook(f, total=L):
#         # _progress += 1
#         line = buffer.append(line.strip())
#         # print(f'{_progress/L*100:2.2F}', end='\r')
#         if len(buffer) > 40000:
#             buffer = ' '.join(buffer)
#             tokens = list(en.tokenizer(buffer.lower()))
#             buffer = []
#             for w in tokens:
#                 vocab[w.text] += 1

#     # One last time to clean the buffer
#     buffer = ' '.join(buffer)
#     tokens = list(en.tokenizer(buffer.lower()))
#     buffer = []
#     for w in tokens:
#         vocab[w.text] += 1

Create or Load Vocab

The following step will take some time, up to 10 minutes, because the spaCy tokenizer is not that fast. But this is a one-time process. Once the vocab file is created, you can just load it from there.


Loads a saved vocab class object with word2index and index2word functions.

lang_file = "./models/wiki103.large.lang"
if not os.path.isfile(lang_file):
    print("Creating vocab file...")
    en_lang = Lang('wiki')
    en_lang.buildLang(open(data_path + f'wiki.{split}.tokens'), num_lines=train_lines)
    with open(lang_file, 'wb') as f:
        # pickle.dump takes the object first, then the file handle
        pickle.dump(en_lang, f)
else:
    print("Loading vocab file...")
    en_lang = pickle.load(open(lang_file, 'rb'))
Loading vocab file...
CPU times: user 155 ms, sys: 64.3 ms, total: 219 ms
Wall time: 190 ms

Limit vocab size

We only consider a vocab size of 40000 for now. This version of the model is based on English words seen in the training dataset. To decrease the number of <UNK> tokens in the dataset, I kept the vocab size a bit large. If we use BPE, the vocab size can be decreased while keeping better coverage.
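To make the idea of a capped vocabulary concrete, here is a minimal sketch in plain Python. It is not the `Lang` class from `utils.py` (whose internals are not shown in this post); `build_vocab`, `encode` and the `UNK` marker are hypothetical names used only for illustration. It keeps the most frequent tokens and maps everything else to a single unknown id.

```python
from collections import Counter

UNK = "<UNK>"  # hypothetical out-of-vocabulary marker


def build_vocab(tokens, max_size):
    """Keep the max_size most frequent tokens; everything else maps to UNK."""
    counts = Counter(tokens)
    word2index = {UNK: 0}
    for word, _ in counts.most_common(max_size):
        word2index[word] = len(word2index)
    return word2index


def encode(tokens, word2index):
    """Replace each token by its id, falling back to the UNK id."""
    return [word2index.get(w, word2index[UNK]) for w in tokens]


tokens = "the city the economy the city park".split()
w2i = build_vocab(tokens, max_size=2)
ids = encode(tokens, w2i)
print(w2i)
print(ids)
```

With a cap of 2, only "the" and "city" survive in this toy corpus, and the rarer words collapse to the unknown id. A larger cap trades a bigger embedding matrix for fewer unknown tokens.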



Holds the word embedding matrix.

class Embedder(nn.Module):
    def __init__(self, vocab_size, d_model):
        # nn.Module must be initialised before sub-modules are assigned
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
    def forward(self, x):
        return self.embed(x)

Positional Encoder

The transformer doesn’t have any notion of sequence order in its architecture by default, so it can only see its input as a bag of tokens. We therefore need to provide positional information explicitly through the token embedding itself.

class PositionalEncoder(nn.Module):
    def __init__(self, d_model, max_seq_len = MAX_LEN):
        super().__init__()
        self.d_model = d_model

        # create constant 'pe' matrix with values dependant on
        # pos and i
        pe = torch.zeros(max_seq_len, d_model)
        for pos in range(max_seq_len):
            for i in range(0, d_model, 2):
                pe[pos, i] = \
                math.sin(pos / (10000 ** ((2 * i)/d_model)))
                pe[pos, i + 1] = \
                math.cos(pos / (10000 ** ((2 * (i + 1))/d_model)))

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # make embeddings relatively larger
        x = x * math.sqrt(self.d_model)
        #add constant to embedding
        seq_len = x.size(1)
        # 'pe' is a registered buffer: it carries no gradient and moves with the module
        x = x + self.pe[:, :seq_len]
        return x
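For reference, here is a minimal sketch of the sinusoidal encoding as it is written in Vaswani et al., where each sine/cosine pair shares one frequency. `sinusoidal_position` is a hypothetical name for this sketch, and its exponent is not identical to the one in the loop of the class above, so treat it as an illustration of the idea and not as a drop-in replacement.

```python
import math


def sinusoidal_position(pos, d_model):
    """Encoding vector for a single position.

    Even index k holds sin(pos / 10000^(k / d_model)) and the following
    odd index holds the cosine of the same angle.
    """
    vec = []
    for k in range(0, d_model, 2):
        angle = pos / (10000 ** (k / d_model))
        vec.append(math.sin(angle))
        vec.append(math.cos(angle))
    return vec[:d_model]


first = sinusoidal_position(0, 4)
later = sinusoidal_position(3, 4)
print(first)
print(later)
```

Position 0 always encodes to alternating zeros and ones, and every other position gets a different pattern, which is what lets the model tell token positions apart.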

Multi-Head Attention

This is the core part of the transformer architecture. A single layer of Multi-Head Attention applies self-attention to all of its inputs. The input of this operation is a bag of k tokens (each with its own query, key and value representations) and the output is an updated representation of the same k tokens. For every token, the query representation is compared with the key representations of all the tokens to decide the weights (or attention). The updated output representation of that token is then a linear combination of the value representations of the tokens, using the weights just calculated.

In this implementation, we only look behind the current location by masking the indices ahead. This is because we want to predict the next word conditioned on the context behind.

I plan to add the permutation language model functionality based on XLNet to allow learning bidirectional features.

class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout = 0.1):
        super().__init__()

        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):

        bs = q.size(0)

        # perform linear operation and split into h heads

        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)

        # transpose to get dimensions bs * h * sl * d_k

        k = k.transpose(1,2)
        q = q.transpose(1,2)
        v = v.transpose(1,2)

        # calculate attention using function we will define next
        scores = attention(q, k, v, self.d_k, mask, self.dropout)

        # concatenate heads and put through final linear layer
        concat = scores.transpose(1,2).contiguous()\
        .view(bs, -1, self.d_model)

        output = self.out(concat)

        return output


def attention(q, k, v, d_k, mask=None, dropout=None):
    # q, k, v : shape(bs, heads, L_max, d_k)
    # scores: matmul [shape(bs,heads,L_max,d_k), shape(bs,heads,d_k,L_max)] -> shape(bs,heads,L_max,L_max)
    # scores x v : shape(bs,heads,L_max,L_max) X shape(bs,heads,L_max,d_k) -> shape(bs,heads,L_max,d_k)

    scores = torch.matmul(q, k.transpose(-2, -1)) /  math.sqrt(d_k)
    if mask is not None:
        # print(f"### Shape of pre-softmax logits: {scores.shape}")
        # mask = mask.unsqueeze(1)
        # print(f"### Shape of mask: {mask.shape}")
        scores = scores.masked_fill(mask == 0, -1e9)
    scores = F.softmax(scores, dim=-1)

    if dropout is not None:
        scores = dropout(scores)

    output = torch.matmul(scores, v)
    return output
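To see what the look-behind mask does to the attention weights, here is a tiny pure-Python sketch. `causal_attention_weights` is a hypothetical helper, not part of the repo. It takes a square matrix of raw scores, replaces every entry to the right of the diagonal with a large negative number (the same `-1e9` trick as above), and applies a softmax per row.

```python
import math


def causal_attention_weights(scores):
    """Row i gets a softmax over columns 0..i; later columns are masked out."""
    n = len(scores)
    weights = []
    for i in range(n):
        masked = [scores[i][j] if j <= i else -1e9 for j in range(n)]
        top = max(masked)
        exps = [math.exp(s - top) for s in masked]  # subtract max for stability
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


w = causal_attention_weights([[0.0, 0.0, 0.0] for _ in range(3)])
for row in w:
    print(row)
```

With equal raw scores, the first token can only attend to itself, the second splits its attention over two positions, and so on. No row puts any weight on a position ahead of it.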


A simple feed forward network with one hidden layer. Input and output dimensions are d_model and hidden layer size is d_ff (=2048 by default).

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=2048, dropout = 0.1):
        super().__init__()
        # We set d_ff as a default to 2048
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)
    def forward(self, x):
        x = self.dropout(F.relu(self.linear_1(x)))
        x = self.linear_2(x)
        return x

Layer Norm

Following the trend in various papers, we also apply Layer Norm with every Multi-Head Attention and Feed Forward layer. In this implementation the norm is applied to the input of each of these layers.

class Norm(nn.Module):
    def __init__(self, d_model, eps = 1e-6):
        super().__init__()

        self.size = d_model
        # create two learnable parameters to calibrate normalisation
        self.alpha = nn.Parameter(torch.ones(self.size))
        self.bias = nn.Parameter(torch.zeros(self.size))
        self.eps = eps
    def forward(self, x):
        norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) \
        / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias
        return norm
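The same normalisation can be checked by hand on a single vector. The sketch below is plain Python with a hypothetical `layer_norm` function and scalar `alpha` and `bias` (the class above learns one value per dimension). It uses the sample standard deviation, since PyTorch's `Tensor.std` also defaults to the unbiased (n - 1) estimator.

```python
import statistics


def layer_norm(x, alpha=1.0, bias=0.0, eps=1e-6):
    """Shift x to zero mean and scale it to roughly unit standard deviation."""
    mean = statistics.mean(x)
    std = statistics.stdev(x)  # sample std, divides by n - 1
    return [alpha * (v - mean) / (std + eps) + bias for v in x]


out = layer_norm([1.0, 2.0, 3.0])
print(out)
```

The output is centred on zero and its spread no longer depends on the scale of the input, which keeps activations in a similar range from layer to layer.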

Encoder Layer

Puts together a single layer of the Encoder. This applies [LayerNorm -> Multi-Head Attn -> LayerNorm -> Feed Forward] to the input.

# build an encoder layer with one multi-head attention layer and
# one feed-forward layer
class EncoderLayer(nn.Module):
    def __init__(self, d_model, heads, dropout=0.1):
        super().__init__()
        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)

        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

        self.attn_1 = MultiHeadAttention(heads, d_model)
        self.ff = FeedForward(d_model)

    def forward(self, x, trg_mask):
        x2 = self.norm_1(x)
        x = x + self.dropout_1(self.attn_1(x2, x2, x2, trg_mask))
        x2 = self.norm_2(x)
        x = x + self.dropout_2(self.ff(x2))
        return x

# We can then build a convenient cloning function that can generate multiple layers:
def get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


Puts together the whole network by stacking

  • Word Embedding Matrix
  • Positional Encoder
  • N multihead attention layers
class Encoder(nn.Module):
    def __init__(self, vocab_size, d_model, N, heads):
        super().__init__()
        self.N = N
        self.embed = Embedder(vocab_size, d_model)
        self.pe = PositionalEncoder(d_model)
        self.layers = get_clones(EncoderLayer(d_model, heads), N)
        self.norm = Norm(d_model)
    def forward(self, trg, trg_mask):
        x = self.embed(trg)
        x = self.pe(x)
        for i in range(self.N):
            x = self.layers[i](x, trg_mask)
        return self.norm(x)


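Final wrapper class for the Transformer. It is nothing but the Encoder along with a final linear projection layer that projects the output representation to a score (logit) for every word in the vocab. The softmax that turns these scores into probabilities is applied later, by the loss function or at sampling time.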

class Transformer(nn.Module):
    def __init__(self, trg_vocab, d_model, N, heads):
        super().__init__()
        self.encoder = Encoder(trg_vocab, d_model, N, heads)
        # self.decoder = Decoder(trg_vocab, d_model, N, heads)
        self.out = nn.Linear(d_model, trg_vocab)
    def forward(self, trg, trg_mask):
        # e_outputs = self.encoder(src, src_mask)
        d_output = self.encoder(trg, trg_mask)
        output = self.out(d_output)
        return output
        # we don't perform softmax on the output as this will be handled
        # automatically by our loss function
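Since the model returns raw scores, it may help to see the step that the loss function performs internally. The following is a minimal pure-Python sketch of a numerically stable log-softmax. `log_softmax` here is a hypothetical stand-alone function written for this illustration, not the PyTorch one.

```python
import math


def log_softmax(logits):
    """Turn raw scores into log probabilities that exponentiate and sum to 1."""
    top = max(logits)
    log_z = top + math.log(sum(math.exp(v - top) for v in logits))
    return [v - log_z for v in logits]


log_probs = log_softmax([1.0, 2.0, 3.0])
probs = [math.exp(v) for v in log_probs]
print(log_probs)
print(probs)
```

The ordering of the scores is preserved, so taking the argmax over the raw output of the model gives the same word as taking it over the probabilities.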


WikiDataset class for fetching samples from wikitext-103 dataset.

class WikiDataset(Dataset):
    """One split of wikitext-103, encoded with en_lang and padded or truncated to max_len.

    ``__len__`` gives the number of non-empty lines in the split and
    ``__getitem__`` returns the list of token ids for one line.
    """

    def __init__(self, split, max_len=MAX_LEN):
        super(WikiDataset, self).__init__()
        if split == 'train':
            _file = data_path + '/wiki.train.tokens'
            n_lines = 1801350
        elif split=="valid":
            _file = data_path + '/wiki.valid.tokens'
            n_lines = 3760
        elif split=="test":
            _file = data_path + '/wiki.test.tokens'
            n_lines = 4358
        else:
            raise Exception(f"wrong split: {split}")
        print("File:", _file)
        print("Expected # of lines:", n_lines)
        self.data = []
        with open(_file) as f:
            for line in tqdm_notebook(f, total=n_lines):
                line = line.strip()
                if len(line) > 0:
                    el = en_lang.encodeSentence(line)
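                    if len(el) < max_len:
                        # short line: close with EOS, then pad up to max_len
                        el = el + [en_lang.iEOS] + [en_lang.iPAD]*(max_len - len(el) - 1)
                    else:
                        # long line: cut and close with EOS
                        el = el[:(max_len - 1)] + [en_lang.iEOS]
                    self.data.append(el)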

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)
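The padding and truncation rule used in the class above can be checked in isolation. This is a minimal sketch with a hypothetical `pad_or_truncate` function and made-up ids for EOS and PAD. The real ids come from `en_lang.iEOS` and `en_lang.iPAD`.

```python
def pad_or_truncate(ids, max_len, eos, pad):
    """Return exactly max_len ids: the sentence, one EOS, then padding."""
    if len(ids) < max_len:
        return ids + [eos] + [pad] * (max_len - len(ids) - 1)
    return ids[:max_len - 1] + [eos]


short = pad_or_truncate([5, 6], max_len=5, eos=1, pad=0)
long = pad_or_truncate([5, 6, 7, 8, 9, 10], max_len=5, eos=1, pad=0)
print(short)
print(long)
```

Every sample ends up with the same length, which is what allows the `DataLoader` to batch them without a custom collate function.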

Dataset Split

wikiDataset_valid = WikiDataset('valid')
wikiDataset_test = WikiDataset('test')
wikiDataset_train = WikiDataset('train')
File: data/wikitext-103//wiki.valid.tokens
Expected # of lines: 3760

File: data/wikitext-103//wiki.test.tokens
Expected # of lines: 4358

File: data/wikitext-103//wiki.train.tokens
Expected # of lines: 1801350

Model Config

  • d_model: Embedding dim of words
  • heads: Number of heads used for multi-head attention
  • N: Number of MHA layers
d_model = 512
heads = 32
N = 8
_vocab = en_lang.VOCAB_SIZE
model = Transformer(_vocab, d_model, N, heads)
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
# this code is very important! It initialises the parameters with a
# range of values that stops the signal fading or getting too big.
# See this blog for a mathematical explanation.
optim = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
_ = model.cuda()
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"## Training model with {pytorch_total_params/1000000:0.2F}M trainable parameters.")
## Training model with 66.22M trainable parameters.
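The 66.22M figure can be reproduced by hand from the layer definitions above. The sketch below uses a hypothetical `count_params` helper and assumes a vocabulary of exactly 40000 entries. The true `en_lang.VOCAB_SIZE` may differ by a few special tokens, so take the vocabulary size as an assumption.

```python
def count_params(vocab_size, d_model, n_layers, d_ff=2048):
    """Trainable parameters of the model defined in this post."""
    embedding = vocab_size * d_model
    # q, k, v and output projections, each a d_model x d_model Linear with bias
    attention = 4 * (d_model * d_model + d_model)
    feed_forward = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
    norms = 2 * (2 * d_model)  # two Norm modules per layer, alpha and bias each
    per_layer = attention + feed_forward + norms
    final_norm = 2 * d_model
    output_proj = d_model * vocab_size + vocab_size
    return embedding + n_layers * per_layer + final_norm + output_proj


total = count_params(vocab_size=40000, d_model=512, n_layers=8)
print(f"{total / 1e6:0.2f}M")  # prints 66.22M
```

Roughly 41M of these parameters sit in the embedding matrix and the output projection, so the vocabulary size dominates the model size far more than the number of layers does. The number of heads does not change the count at all, because the heads only split the existing `d_model` dimensions.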
wikiDataloader_train = DataLoader(wikiDataset_train, batch_size=32)
wikiDataloader_valid = DataLoader(wikiDataset_valid, batch_size=32)
wikiDataloader_test = DataLoader(wikiDataset_test, batch_size=32)
print(f"## Steps per epoch {len(wikiDataloader_train.dataset)//wikiDataloader_train.batch_size}")
## Steps per epoch 36407
def train_model(epochs=1, print_every=100):
    # NOTE: `epochs` and `print_every` are not defined anywhere above;
    # the default values here are placeholders, adjust them as needed.
    _ = model.train()
    start = time.time()
    temp = start

    total_loss = 0

    for epoch in range(epochs):
        for i, batch in enumerate(wikiDataloader_train):
            # the default collate_fn yields a list of max_len tensors of shape (bs,)
            batch = torch.stack(batch).to(device)
            trg = batch.t()

            # the input has all words except the last,
            # as it is using each word to predict the next
            trg_input = trg[:, :-1]

            # the words we are trying to predict
            targets = trg[:, 1:].contiguous().view(-1)

            # create mask to make sure attn reads input only from the left (autoregressive)
            trg_mask = torch.tensor(np.tril(np.ones(
                    (1, 1, trg_input.shape[1], trg_input.shape[1]))
            ), device=device) * ((trg_input != en_lang.iPAD).double().unsqueeze(1).unsqueeze(1))

            preds = model(trg_input, trg_mask)

            optim.zero_grad()
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)), targets, ignore_index=en_lang.iPAD)
            loss.backward()
            optim.step()

            total_loss += loss.item()
            if (i + 1) % print_every == 0:
                loss_avg = total_loss / print_every
                print("time = %dm, epoch %d, iter = %d, loss = %.3f, PPL = %8.2f, %ds per %d iters" % ((time.time() - start) // 60,
                epoch + 1, i + 1, loss_avg, math.exp(loss_avg), time.time() - temp,
                print_every))
                total_loss = 0
                temp = time.time()
                # raise Exception("STOP")
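The PPL value printed by the training loop is simply the exponential of the average cross-entropy. Here is a minimal sketch of that relationship in plain Python. `perplexity` is a hypothetical helper and the list of losses is made up for the example.

```python
import math


def perplexity(token_losses):
    """exp of the mean per-token cross-entropy (natural log)."""
    return math.exp(sum(token_losses) / len(token_losses))


# a model that is always torn evenly between 4 words pays log(4) per token
uniform_over_4 = [math.log(4)] * 10
ppl = perplexity(uniform_over_4)
print(round(ppl, 6))
```

A perplexity of 4 can be read as the model being as uncertain as if it were choosing uniformly among 4 words at every step. Note that the loop above averages batch-level means, which is only an approximation of the token-level mean when batches contain different amounts of padding.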

Training the model

# train_model()

Save the model

# torch.save({
#     'epoch': epoch,
#     'iter': i,
#     'model_state_dict': model.state_dict(),
#     'optimizer_state_dict': optim.state_dict(),
#     'loss': loss
# }, "models/txl_wikitext103.pth")

Load a saved model

checkpoint = torch.load("models/txl_wikitext103.pth")
model.load_state_dict(checkpoint['model_state_dict'])
IncompatibleKeys(missing_keys=[], unexpected_keys=[])

Prediction on Test Set

The following function runs the model on the test or validation set. You can use this function to calculate perplexity on the validation or test set to compare. I didn’t bother doing this as I was more interested in the contextual text generation task. That’s the next function.

def sample_sequence():
    with torch.no_grad():
        start_from = 6
        for i, batch in enumerate(wikiDataloader_valid):
            batch = torch.stack(batch).to(device)
            trg = batch.t()
            zl = None
            trg_input = trg[:, :-1]
            trg_mask_common = torch.tensor(np.tril(np.ones(
                        (1, 1, trg_input.shape[1], trg_input.shape[1]))
                ), device=device) * ((trg_input != en_lang.iPAD).double().unsqueeze(1).unsqueeze(1))

            # the words we are trying to predict
            targets = trg[:, 1:].contiguous().view(-1)

            for j in range(start_from, MAX_LEN - 1):
                # Predicting (j+1)th word
                if zl is None:
                    zl = torch.tensor(np.zeros((1, 1, trg_input.shape[1], trg_input.shape[1]))
                                      , device=device).double()
                # open up key positions 0..j-1 for this step
                zl[..., :j] = 1

                # create mask to make sure attn reads input only from the left (autoregressive)
                trg_mask =  trg_mask_common * zl

                preds = F.softmax(model(trg_input, trg_mask)[...,j,:], dim=-1)
#                 samples = torch.multinomial(preds, 1)[:,0]
                samples = torch.argmax(preds, 1)
                # samples = samples.view(trg_input.shape[0], -1)

                trg[..., (j+1)] = samples
                print(f"{j}", end="\r")
            return trg
preds = sample_sequence()
en_lang.decodeSentence(preds[16, :].cpu().tolist())

Text sampler

Finally, the text generator function. This is inspired by the talktotransformer site, which blew me away. Of course, the model trained here is not as good as the fine-tuned GPT-2 model used for talktotransformer, but it gives a good flavour of the task.

def talk_to_me(context, max_len = MAX_LEN):
    context = torch.tensor(en_lang.encodeSentence(context)).unsqueeze(0).to(device)
    with torch.no_grad():
        start_from = (context.shape[1] - 1)
        # for i, batch in enumerate(wikiDataloader_valid):
        trg_input = context.to(device)
        # trg = batch.t()
        zl = None
        # trg_input = trg[:, :-1]
        trg_input = F.pad(trg_input, (0, MAX_LEN - trg_input.shape[1]), "constant", en_lang.iEOS)

        trg_mask_common = torch.tensor(np.tril(np.ones(
                    (1, 1, MAX_LEN, MAX_LEN))
            ), device=device) * ((trg_input != en_lang.iPAD).double().unsqueeze(1).unsqueeze(1))

        for j in range(start_from, MAX_LEN - 1):
            # Predicting (j+1)th word
            if zl is None:
                zl = torch.tensor(np.zeros((1, 1, trg_input.shape[1], trg_input.shape[1]))
                                  , device=device).double()
            # open up key positions 0..j-1 for this step
            zl[..., :j] = 1

            # create mask to make sure attn reads input only from the left (autoregressive)
            trg_mask =  trg_mask_common * zl

            preds = F.softmax(model(trg_input, trg_mask)[...,j,:], dim=-1)
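            if np.random.rand() < 0.2:
                # occasionally sample from the distribution for variety
                samples = torch.multinomial(preds, 1)[:,0]
            else:
                # otherwise take the most likely word
                samples = torch.argmax(preds, 1)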
            # samples = samples.view(trg_input.shape[0], -1)

            trg_input[..., (j+1)] = samples
            if samples.item() == en_lang.iEOS:
                return trg_input
            print(f"{j}", end="\r")
        return trg_input
query = "Bangalore has the best"
for i in range(5):
    gen_text = talk_to_me(query)
    print(f"Sample {i}: ", ' '.join(en_lang.decodeSentence(gen_text.cpu()[0].numpy().tolist())))
Sample 0:  bangalore has the best of the city s economy , with a total of 1 , <UNK> , unions and 3 , <UNK> . the city s economy is dominated by agriculture , which is often cool and dry . the economy is dominated by agriculture , agriculture and agriculture , and is governed by a <UNK> system of yielded industries . the city s economy is multi - ethnic , with a population of around 1 , 000 . the city s economy is dominated by agriculture , agriculture , and agriculture . there is a large number of industries in fictional and rural areas .
Sample 1:  bangalore has the best of the city s economy daily economy , with approximately 70 , 000 people making it the second largest economy in india . the city s economy is dominated by agriculture , agriculture , and agriculture . airways told its public - sector bacon , the largest industry in india and the largest industry in india . the city s economy is agriculture with a total economy of 96 . 39 of the city s total economy . the city s economy is a major industry in the country , since it is a major industry in the country and is a major industry in the torpedoes industry .
Sample 2:  bangalore has the best schools of the city , with a total of 4 , 000 students . the city s largest schools is the local library and the <UNK> library . the city s cultural and cultural centers are located in the city . other provincial parks include the <UNK> national park , the <UNK> national park , the <UNK> national park , the <UNK> national park , the <UNK> national historic park , the <UNK> national park and the <UNK> national park . the city s largest public park is the <UNK> park , which is puzzle park park . the city s largest bird park is the universal park park , which contains the largest marine park officer park in the world .
Sample 3:  bangalore has the best of the city s economy , with a total of 1 , <UNK> , <UNK> 79 . 7 million . the manufacturing sector is based in the city s western suburbs . there are large , large centre - facing system of retail and retail space , which helps provide a large number of retail stores . the city s economy is dominated by agriculture , agriculture , and agriculture . the city s economy is dominated by agriculture , agriculture , agriculture , and agriculture .
Sample 4:  bangalore has the best of the above - average football season , with 1 , <UNK> games played , and the lowest average of any season in the country . the city s highest average attendance total was 1 , 132 , 000 in the city s first season in the city s history . the city s lowest attendance total was the 1 , <UNK> , which was the highest for a season in the city s history . three of the lowest attendance figures for a football game were in the city s first season in the city , a record that was set by rule - based football league club sign - based football club , the copying affiliated club . the city s lowest attendance total was the 1 , <UNK> , 38 , 000 in september 2008 . the lowest attendance figures for a football league game were in the city s third win of the season , which was the garage nautical park s highest attendance death total . the lowest attendance opening - day attendance revenue of the year was 1 , <UNK> , 000 .
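The five samples differ because `talk_to_me` occasionally draws a word at random instead of always taking the most likely one. Here is a minimal sketch of that decoding rule in plain Python, assuming the 0.2 threshold from the function above. `pick_next` and its `explore` parameter are hypothetical names used for illustration.

```python
import random


def pick_next(probs, explore=0.2, rng=random):
    """With probability `explore` sample from probs, otherwise take the argmax."""
    if rng.random() < explore:
        return rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return max(range(len(probs)), key=probs.__getitem__)


probs = [0.1, 0.7, 0.2]
greedy = pick_next(probs, explore=0.0)   # never samples
sampled = pick_next(probs, explore=1.0)  # always samples
print(greedy, sampled)
```

Pure greedy decoding would return the same continuation on every call, while sampling at every step tends to drift off topic. Mixing the two is a cheap compromise, and it may also explain why the samples above still fall into repetitive loops: most steps are greedy.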


Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. “Attention is all you need.” In Advances in neural information processing systems, pp. 5998-6008. 2017.

Dai, Zihang, Zhilin Yang, Yiming Yang, William W. Cohen, Jaime Carbonell, Quoc V. Le, and Ruslan Salakhutdinov. “Transformer-xl: Attentive language models beyond a fixed-length context.” arXiv preprint arXiv:1901.02860 (2019).

Yang, Zhilin, Zihang Dai, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, and Quoc V. Le. “XLNet: Generalized Autoregressive Pretraining for Language Understanding.” arXiv preprint arXiv:1906.08237 (2019).