XLNet

XLNet: Generalized Autoregressive Pretraining for Language Understanding by Yang et al. was published in June 2019. The paper claims to overcome shortcomings of BERT and to achieve state-of-the-art (SOTA) results on many NLP tasks.

In this article I explain XLNet and show the code of a binary classification example on the IMDB dataset. I compare the two models, as I did the same classification with BERT (see here). For the complete code, see my GitHub (here).

Introduction

I hope you remember BERT, as XLNet can be thought of as a model in the BERT family of NLP models. It builds on some ideas introduced by BERT, has a similar model size (allowing comparison) and uses BERT’s limitations as a starting point for what should be improved. Let’s go through these limitations one by one:

1) The BERT language model is trained by masking words in an input sequence and predicting these masked tokens. This is great because it is bidirectional (or non-directional); however, it creates a discrepancy between the pre-training and fine-tuning phases, as no tokens are masked during the latter.

2) This, in turn, leads to an additional problem. The BERT model assumes that the masked tokens are conditionally independent of one another, given the unmasked tokens. That is, predicting a masked token conditional on all non-masked tokens is assumed to be the same as predicting it conditional on all non-masked tokens and the remaining masked tokens.

Suppose your input sequence is:

“I went to [MASK] [MASK] last week.”

 

Can we say whether I went to “New York”, “Los Angeles”, “Las Vegas”, “San Francisco”, “Kuala Lumpur” or simply “a restaurant” last week? No. BERT assumes that the masked tokens are independent of one another, so predicting “San” conditional on all unmasked tokens is treated as the same as predicting it conditional on the unmasked tokens and “Francisco”. Nothing in this objective ties the two masked words together, which is clearly an oversimplifying and limiting assumption.

3) BERT uses a fixed-length input. Its maximum input length is 512 tokens, but the model is often used with sequence lengths of 128 or 256. When I evaluated BERT on the IMDB binary classification task, I observed that it performs much better with the maximum sequence length (>93% accuracy) than with a reduced sequence length of 128 (~89% accuracy). This is a consequence of the nature of the data: reviews are often much longer than 128 tokens. When the input is cut to the first 128 tokens of each review, the rest is simply ignored and a lot of information is lost.

This has another negative consequence: instead of keeping, within the 128-token budget, for instance the one or two most important sentences, we simply cut the text after the first 128 tokens, with no regard for sentence boundaries. This leads to segment chunking.
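Limitation 2) can be made concrete with a toy Python example. The probabilities in JOINT are made up, and the helpers are hypothetical: they do not run BERT. They only contrast predicting the two [MASK] slots independently (a product of per-slot marginals, which is what BERT’s objective implicitly assumes) with the chain rule used by autoregressive models.

```python
# Made-up joint distribution over the two masked slots in
# "I went to [MASK] [MASK] last week."
JOINT = {
    ("New", "York"): 0.4,
    ("San", "Francisco"): 0.3,
    ("Las", "Vegas"): 0.3,
}

def marginal(slot: int, token: str) -> float:
    """P(slot = token | unmasked tokens), summing out the other slot."""
    return sum(p for pair, p in JOINT.items() if pair[slot] == token)

def independent_prob(first: str, second: str) -> float:
    """BERT-style assumption: each masked slot is predicted on its own."""
    return marginal(0, first) * marginal(1, second)

def chain_rule_prob(first: str, second: str) -> float:
    """Autoregressive factorization: P(first) * P(second | first)."""
    p_first = marginal(0, first)
    if p_first == 0.0:
        return 0.0
    p_second_given_first = JOINT.get((first, second), 0.0) / p_first
    return p_first * p_second_given_first

for pair in [("San", "Francisco"), ("San", "York")]:
    print(pair, round(independent_prob(*pair), 2), round(chain_rule_prob(*pair), 2))
```

With these numbers, the independent factorization gives “San York” a probability of 0.12 even though it never occurs, while the chain rule assigns it 0 and gives “San Francisco” its full 0.3.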

But no need to worry: here comes XLNet to overcome these limitations. XLNet keeps the idea of bidirectional context but encodes it differently than BERT: it is the first bidirectional autoregressive method. Before we talk about XLNet, let’s review what an autoregressive method is and what is wrong with it.

Autoregressive language model 

Autoregressive (AR) language models build on the idea of estimating the probability distribution of a sequence of words, tokens or characters. The aim is to estimate the probability of the whole sequence (let’s assume a sentence). We can do that by breaking the sentence down into its elements and estimating the probability of each element conditional on some other elements.

AR methods therefore seek to estimate the probability distribution of the corpus. Let’s assume that our sequence $\mathbf{x}$ contains T words, $\mathbf{x} = (x_1, \dots, x_T)$. AR models then estimate the probability distribution of $\mathbf{x}$ with a forward or a backward autoregressive factorization, respectively:

$$p(\mathbf{x}) = \prod_{t=1}^{T} p(x_t \mid \mathbf{x}_{<t}) \qquad \text{or} \qquad p(\mathbf{x}) = \prod_{t=T}^{1} p(x_t \mid \mathbf{x}_{>t})$$

These models predict the word $x_t$ by choosing the word that maximizes the likelihood of the sequence.

Such AR models thus decompose into individual conditional probabilities: the probability of a token given the tokens that precede it, or the probability of a token given the tokens that follow it.
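To make the two factorizations concrete, here is a minimal Python sketch. It assumes a small, made-up joint distribution over three-token sequences; JOINT, p_given_prefix and p_given_suffix are hypothetical helpers that read the exact conditionals off that table, whereas a real AR model would learn them. Because both directions are exact applications of the chain rule, they recover the same p(x).

```python
import math

# Made-up joint distribution over three-token sequences.
JOINT = {
    ("it", "is", "the"): 0.5,
    ("it", "is", "only"): 0.3,
    ("is", "it", "only"): 0.2,
}
SEQ_LEN = 3

def p_given_prefix(token, prefix):
    """Exact p(x_t = token | x_<t = prefix), computed from JOINT."""
    t = len(prefix)
    den = sum(p for s, p in JOINT.items() if s[:t] == prefix)
    num = sum(p for s, p in JOINT.items() if s[:t] == prefix and s[t] == token)
    return num / den

def p_given_suffix(token, suffix):
    """Exact p(x_t = token | x_>t = suffix), computed from JOINT."""
    t = SEQ_LEN - len(suffix) - 1
    den = sum(p for s, p in JOINT.items() if s[t + 1:] == suffix)
    num = sum(p for s, p in JOINT.items() if s[t + 1:] == suffix and s[t] == token)
    return num / den

def forward_log_prob(x):
    """log p(x) = sum_t log p(x_t | x_<t)."""
    return sum(math.log(p_given_prefix(x[t], tuple(x[:t]))) for t in range(len(x)))

def backward_log_prob(x):
    """log p(x) = sum_t log p(x_t | x_>t)."""
    return sum(math.log(p_given_suffix(x[t], tuple(x[t + 1:]))) for t in range(len(x)))

x = ("it", "is", "only")
print(math.exp(forward_log_prob(x)), math.exp(backward_log_prob(x)))
```

Both numbers equal 0.3 up to floating-point rounding, which is the probability of the sequence in JOINT.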

AR methods have several advantages: they are intuitive, they make no independence assumption between tokens, and they decompose easily. But there is a problem with traditional AR methods: they are unidirectional. Let’s look at this point more closely: where does the problem really lie?


The problem is actually the sequential ordering of the elements (tokens or words) on the conditional side of the probability. That is, the problem is not P(word | something) itself, but that “something” is always unidirectional and sequentially ordered. Ideally, we would have something like P(word | all words around the predicted word). So let’s keep the same idea but approach it from a different angle.

XLNet’s approach: Permutational Language Modelling 

XLNet comes to our rescue: it provides a clever way to make these conditional probabilities depend not only on the tokens preceding the predicted word but also on the tokens following it. The trick is to maximize the expected log-likelihood of a sequence with respect to all possible permutations of the factorization order. Hold on a second; let’s see step by step what this really means.

Suppose we wish to predict one token of the following sequence (a quote by Miles Davis):


“Time isn’t the main thing. It is the only thing.”

I keep only the second sentence (for simplicity in the explanations and in the figures!)

“It is the only thing.”

After tokenization, the sequence becomes:

“it is the only thing”

Now let’s number all tokens: {1: ‘it’, 2: ‘is’, 3: ‘the’, 4: ‘only’, 5: ‘thing’}

“it is the only thing” = [x_1, x_2, x_3, x_4, x_5]

Now suppose that we wish to predict the 3rd token (‘the’). For a sequence of 5 words, there are 5! = 120 possible orderings of words.

Each ordering is a permutation of [x_1, x_2, x_3, x_4, x_5], for example [x_3, x_1, x_2, x_4, x_5] or [x_5, x_2, x_1, x_3, x_4]. XLNet samples a random order from the 120 possible ones and builds an AR language model on the elements of the sequence, but this time it predicts the second element of the factorization order conditional on the first, the third element conditional on the first two, and so on. The emphasis is on the factorization order: the model determines the order of the conditional probabilities from the factorization order and not from the natural, sequential order of words. The following two GIFs show the difference between the forward autoregressive (first figure) and permutational language modelling (second figure) approaches.

[GIF: traditional (forward) language modelling]

[GIF: permutational language modelling]
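The sampling step described above can be sketched in a few lines of Python. This is a minimal illustration, assuming a uniformly sampled order and 1-based positions as in the text; ar_decomposition is a hypothetical helper, not XLNet’s implementation. For each step of the factorization order it lists which tokens the prediction may condition on.

```python
import math
import random

# Natural positions (1-based, as in the text) mapped to tokens.
tokens = {1: "it", 2: "is", 3: "the", 4: "only", 5: "thing"}

def ar_decomposition(order):
    """Return (target, context) pairs for one factorization order.

    Each target position is predicted from the positions that precede it
    in the factorization order. The context is listed in natural order,
    because the factorization order decides which tokens are visible,
    not how they are arranged."""
    return [(pos, sorted(order[:t])) for t, pos in enumerate(order)]

print(math.factorial(len(tokens)))  # 120 possible factorization orders

rng = random.Random(0)
z = rng.sample(sorted(tokens), k=len(tokens))  # one order, sampled uniformly
for target, context in ar_decomposition(z):
    print(tokens[target], "|", [tokens[p] for p in context])
```

For instance, the fourth entry of ar_decomposition([5, 2, 1, 3, 4]) is (3, [1, 2, 5]): “the” is predicted from “it”, “is” and “thing”.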

The paper refers to this factorization order as z and to the set of all possible factorization orders as Z (z ∈ Z). Let’s take some examples so that the approach becomes clearer. Suppose the randomly chosen ordering is [x3, x1, x2, x4, x5] = [the, it, is, only, thing]. The 3rd token is now the first element of the factorization order, so the model can predict it only by looking at the memory from the previous segment (I’ll talk about this later).

This situation corresponds to the first figure (left side) in the picture below. After predicting the third token of the original sentence, which is the first element of the factorization order (“the”), the model predicts the second element of the factorization order (“it”, originally the first token of the sentence) conditional on the first element of the factorization order (“the”), and so on.

[Figure: predicting x3 with the factorization orders [x3, x1, x2, x4, x5] (left) and [x5, x2, x1, x3, x4] (right)]

Now assume that the sampled factorization order is [x5, x2, x1, x3, x4] = [thing, is, it, the, only]. The 3rd word is now the fourth element of the factorization order, so we can use three tokens, [x5, x2, x1] = [thing, is, it], to predict it. Note, however, that the factorization order only determines which tokens can be used to predict the desired token; it does not represent their order. The order of the tokens stays sequential and is the same for all factorization orders. We can now see that the model is bidirectional: it can predict the 3rd token using tokens that follow it (here, x5). This is illustrated in the second figure (right side) of the image above.

Finally, consider the factorization order [x_5, x_2, x_3, x_4, x_1] = [thing, is, the, only, it]. The model can now predict x3 using the memory of the previous segment and the second and fifth tokens. This case is shown in the image below.

[Figure: predicting x3 with the factorization order [x5, x2, x3, x4, x1]]

I think you get the idea, but to summarize, keep in mind that:

  • The factorization order identifies the tokens that will be used to predict a token (x3 in the examples above): the tokens that precede the desired token in the factorization order are used to build an autoregressive language model.
  • The order of the tokens does not change; it follows the natural, sequential ordering of the tokens. This is ensured by their positional encoding, which I will talk about later.
  • This approach is called permutation language modelling. In expectation, the model sees all possible permutations of the factorization order, that is, the complete bidirectional context of the predicted token.
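In the XLNet paper, the permutation is not implemented by shuffling the input. The tokens keep their original order and positional encodings, and an attention mask decides which positions each token may see. Below is a minimal NumPy sketch of such a mask, assuming 0-based positions; permutation_mask is a hypothetical helper, not XLNet’s code.

```python
import numpy as np

def permutation_mask(order):
    """mask[i, j] is True when the token at natural position j precedes the
    token at position i in the factorization order, i.e. when j may be used
    to predict i. Rows and columns stay in the natural sentence order; only
    the mask encodes the permutation."""
    n = len(order)
    rank = np.empty(n, dtype=int)
    rank[np.asarray(order)] = np.arange(n)  # rank[p] = step at which p is predicted
    return rank[None, :] < rank[:, None]

# [x5, x2, x1, x3, x4] from the example above, as 0-based positions.
z = [4, 1, 0, 2, 3]
mask = permutation_mask(z)
print(mask.astype(int))
```

The row for x3 (index 2) allows x1, x2 and x5, exactly the three tokens listed in the second example, while x5, first in the order, sees no other token.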

With the permutation language modelling explained above, the objective becomes:

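$$\max_{\theta} \; \mathbb{E}_{\mathbf{z} \sim \mathcal{Z}} \left[ \sum_{t=1}^{T} \log p_{\theta}\left(x_{z_t} \mid \mathbf{x}_{\mathbf{z}_{<t}}\right) \right]$$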

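For our first example order, z = [3, 1, 2, 4, 5], the sum inside the expectation becomes:

$$\log p_{\theta}(x_3) + \log p_{\theta}(x_1 \mid x_3) + \log p_{\theta}(x_2 \mid x_3, x_1) + \log p_{\theta}(x_4 \mid x_3, x_1, x_2) + \log p_{\theta}(x_5 \mid x_3, x_1, x_2, x_4)$$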

where we search for the parameters θ of the model that maximize the likelihood of each predicted token conditional on the tokens that precede it in the factorization order z ($x_{z_t}$ is the t-th token of the factorization order z, and $\mathbf{x}_{\mathbf{z}_{<t}}$ are the tokens in its first t-1 positions).

The workflow of the model is therefore as follows: for a text sequence x, it samples a factorization order z from the set of all possible factorization orders Z. Next, it decomposes the likelihood $p_{\theta}(\mathbf{x})$ according to the factorization order, while keeping the natural order of the tokens. The model parameters θ are shared across all factorization orders during training, so in expectation each token $x_t$ sees every possible factorization order and thus all surrounding tokens. This is how the model captures bidirectional context.
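The claim that a single set of parameters can serve every factorization order can be checked on a toy example. The Python sketch below assumes a made-up joint distribution JOINT and computes exact conditionals from it with a hypothetical helper cond_prob. A trained XLNet only approximates such conditionals. Because each order is a valid chain-rule decomposition, all 3! = 6 orders give the same likelihood.

```python
import itertools
import math

# Made-up joint distribution over three-token sentences.
JOINT = {
    ("it", "is", "the"): 0.5,
    ("it", "is", "only"): 0.3,
    ("is", "it", "only"): 0.2,
}

def cond_prob(x, target, context):
    """Exact p(x[target] | x[c] for c in context), with 0-based positions.

    Each token is identified by its natural position, so the conditional
    does not depend on the order in which the context was revealed."""
    def agrees(s, positions):
        return all(s[i] == x[i] for i in positions)
    den = sum(p for s, p in JOINT.items() if agrees(s, context))
    num = sum(p for s, p in JOINT.items() if agrees(s, context + [target]))
    return num / den

def log_likelihood(x, order):
    """sum_t log p(x_{z_t} | x_{z_<t}) for one factorization order z."""
    return sum(math.log(cond_prob(x, order[t], list(order[:t])))
               for t in range(len(order)))

x = ("it", "is", "only")
for z in itertools.permutations(range(len(x))):
    print(z, round(math.exp(log_likelihood(x, z)), 6))
```

Each of the six orders prints 0.3, the probability of the sequence in JOINT. During training, XLNet computes this kind of sum for one sampled order per sequence rather than for all of them.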

This language modelling approach has some desirable properties; in particular, it addresses two weaknesses of BERT. It makes no independence assumption between tokens, and since no tokens are masked, it creates no discrepancy between the pre-training and fine-tuning phases. However, a naive implementation of the permutation language modelling objective with the Transformer architecture does not work. The next section explains why.

Problems with a naive implementation of the permutation language modelling objective in the Transformer architecture

Let’s assume that we wish to predict the second element of two factorization orders: [“thing”, “only”, …] and [“thing”, “it”, …]. With a naive Transformer, the conditional probabilities of these two tokens are:

  • P(only | thing) = P(x | thing)
  • P(it | thing) = P(x | thing)

Both predictions receive exactly the same input, so they give the same distribution, even though the targets sit at different positions (4 and 1)!

Computing the probability of a token conditional on the preceding tokens in the factorization order this way clearly neglects an important piece of information: the position of the token to predict. What we would like to have is something like:

  • P(only | 4, {thing:5})
  • P(it | 1, {thing:5})

This would be much better; however, the standard Transformer architecture cannot express it. What it can express is the following:

  • P(only | {only:4}, {thing:5})
  • P(it | {it:1}, {thing:5})

 

But then the content of the token to predict is part of its own context, so the learning objective becomes trivial and the model will not learn anything useful. XLNet provides a way to compute P(only | 4, {thing:5}) and P(it | 1, {thing:5}): Two-Stream Attention.
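The problem, and the fix, can be illustrated with a deliberately tiny NumPy “model” that sums the embeddings of its inputs and scores every vocabulary word against the result. This is not XLNet’s architecture. The random embeddings and both predict functions are hypothetical. The naive version has no way to use the position it is predicting, so it returns one distribution for both targets. Adding the target position (but not the target’s content) separates them.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab = ["it", "is", "the", "only", "thing"]
d = 8
tok_emb = rng.normal(size=(len(vocab), d))      # hypothetical token embeddings
pos_emb = rng.normal(size=(len(vocab) + 1, d))  # hypothetical position embeddings, rows 1..5 used

def softmax(logits):
    e = np.exp(logits - logits.max())
    return e / e.sum()

def encode(context):
    """Sum of token + position embeddings over (token, position) pairs."""
    return sum(tok_emb[vocab.index(w)] + pos_emb[p] for w, p in context)

def naive_predict(context, target_pos):
    """Naive P(x | context): target_pos is accepted but unused, mirroring a
    standard Transformer that only sees the context tokens."""
    return softmax(tok_emb @ encode(context))

def target_aware_predict(context, target_pos):
    """P(x | target_pos, context): adds the target's position, not its content."""
    return softmax(tok_emb @ (encode(context) + pos_emb[target_pos]))

context = [("thing", 5)]
print(np.allclose(naive_predict(context, 4), naive_predict(context, 1)))
print(np.allclose(target_aware_predict(context, 4), target_aware_predict(context, 1)))
```

The naive version returns identical distributions for positions 4 (“only”) and 1 (“it”), while the target-aware version does not.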

Two-Stream Attention

Intuitively, the two-stream mechanism computes two representations (streams) for the token at each position z_t of the factorization order:

  • The content stream, which encodes the content and position of the token at z_t as well as the content and position of all tokens that precede it in the factorization order. This is essentially the same as a standard attention mechanism.
  • The query stream, which encodes the content and position of all tokens that precede z_t in the factorization order, plus the position z_t itself, but not the content of the token at z_t. This is what differs from a standard attention mechanism.

Finally, we use the last layer’s query representation to compute:

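$$p_{\theta}(X_{z_t} = x \mid \mathbf{x}_{\mathbf{z}_{<t}}) = \frac{\exp\left(e(x)^{\top} g_{z_t}^{(M)}\right)}{\sum_{x'} \exp\left(e(x')^{\top} g_{z_t}^{(M)}\right)}$$

where $e(x)$ is the embedding of token $x$ and $g_{z_t}^{(M)}$ is the query stream of position $z_t$ at the last layer $M$.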
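The output softmax scores each vocabulary entry x by the dot product of its embedding e(x) with the last-layer query vector of the target position. Here is a minimal NumPy sketch, assuming random embeddings and a random query vector as stand-ins for trained ones; predict_distribution is a hypothetical helper.

```python
import numpy as np

def predict_distribution(g_last, emb):
    """Softmax over the vocabulary of e(x) . g.

    g_last: last-layer query vector of the target position, shape [d].
    emb: one embedding e(x) per vocabulary entry, shape [V, d]."""
    logits = emb @ g_last
    logits = logits - logits.max()  # subtract the max for numerical stability
    probs = np.exp(logits)
    return probs / probs.sum()

vocab = ["it", "is", "the", "only", "thing"]
rng = np.random.default_rng(0)
emb = rng.normal(size=(len(vocab), 4))  # hypothetical embeddings e(x)
g = rng.normal(size=4)                  # hypothetical query vector for "only"
probs = predict_distribution(g, emb)
for word, p in zip(vocab, probs):
    print(f"{word:>5s} {p:.3f}")
```

Because the query vector never contains the content of the target token, the model cannot simply copy it and has to infer it from the context.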

It might be useful to see this on a schema. Suppose our factorization order is [“thing”:5, “is”:2, “only”:4, “the”:3, “it”:1] and we want to predict the third token of the factorization order, “only” (the 4th token of the natural order).

First, we initialize the content vectors with the word embeddings and the positional encodings, and the query vectors with a learned vector w. Since w is the same for every token, it cannot provide information about the embedding of any token.

Next, we compute the content stream of the first token of the factorization order, “thing”. It includes information about the position and the content of “thing”. Then we compute the content stream of “is”, which includes the position and content of both “is” and “thing”, as “thing” precedes “is” in the factorization order. Both content streams start from the word embeddings. Finally, when computing the query stream of “only”, we include the position of “only” and the content streams of “thing” and “is”, but not the content of “only” itself. The query stream starts from the vector w, which is the same for all tokens.
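The walkthrough above can be sketched as one simplified two-stream layer in NumPy. This is a minimal illustration under stated assumptions, not XLNet’s implementation. It uses single-head attention without learned projections, residual connections, feed-forward layers, memory or relative positional encodings. To give the query stream its target position, it adds the positional encoding to w (XLNet itself handles positions through Transformer-XL-style relative encodings). A token that comes first in the order keeps its query vector unchanged, because there is no memory here to attend to. All helper names are hypothetical.

```python
import numpy as np

def attend(query, keys_values, allowed):
    """Single-head scaled dot-product attention over the allowed rows."""
    kv = keys_values[allowed]
    scores = kv @ query / np.sqrt(query.shape[0])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ kv

def two_stream_layer(h, g, order):
    """One simplified two-stream layer.

    h, g: content and query streams, shape [n, d], rows in natural order.
    order: factorization order as 0-based positions.
    The content stream of z_t attends to content streams of z_<=t;
    the query stream of z_t attends to content streams of z_<t only."""
    n = len(order)
    rank = np.empty(n, dtype=int)
    rank[np.asarray(order)] = np.arange(n)  # rank[p] = step at which p is predicted
    new_h, new_g = h.copy(), g.copy()
    for i in range(n):
        new_h[i] = attend(h[i], h, rank <= rank[i])  # sees its own content
        visible = rank < rank[i]                     # never its own content
        if visible.any():  # the first token of z has no context in this sketch
            new_g[i] = attend(g[i], h, visible)
    return new_h, new_g

rng = np.random.default_rng(0)
n, d = 5, 8
tok_emb = rng.normal(size=(n, d))  # hypothetical embeddings of it, is, the, only, thing
pos_enc = rng.normal(size=(n, d))  # hypothetical positional encodings
w = rng.normal(size=d)             # shared query initialization (learned in XLNet)

h0 = tok_emb + pos_enc             # content stream: embedding + position
g0 = w + pos_enc                   # query stream: w + position (sketch assumption)
z = [4, 1, 3, 2, 0]                # [thing, is, only, the, it]
h1, g1 = two_stream_layer(h0, g0, z)

# Change the embedding of "only" (index 3) and recompute.
tok_emb_changed = tok_emb.copy()
tok_emb_changed[3] += 1.0
h1_changed, g1_changed = two_stream_layer(tok_emb_changed + pos_enc, g0, z)
print(np.allclose(g1[3], g1_changed[3]))  # query stream of "only"
print(np.allclose(h1[3], h1_changed[3]))  # content stream of "only"
```

Changing the embedding of “only” leaves its query stream untouched but changes its content stream. This is the property that lets the query stream predict “only” without seeing it.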

 
