Introduction¶

Bor-bor-zan is a famous game played at my high school bdyz. I won't explain the rules here for simplicity, but note that the game has eight basic moves, and the two players choose their moves simultaneously. The game ends when a deadly move kills one player, which can happen on the very first move or after an arbitrary number of moves. I want to train a model to recognize the game's rules and predict the two players' next moves.

Preparing the data¶

Defining the rules¶

I've created a class that can fully represent the game of bor-bor-zan.

In [2]:
import random
class Game:
    def __init__(self):
        self.winner = 0
        self.e1 = 0 # energy of player 1, etc.
        self.e2 = 0
        self.over = False # the game is still on
        self.requires = [0, 0, 0, 1, 1, 0, 1, 3, 2] # energy required to play each move (index 0 is a placeholder)
        self.text = [0, "zan", "shield", "shoot", "reflect", "buddha", "steal", "chick", "big-shield"]
        self.record = []
    def act(self, action1, action2):
        # moves: 1: zan (save energy); 2: shield; 3: shoot; 4: reflect; 5: buddha; 6: steal; 7: chick; 8: big-shield
        # determine whether either player was just killed
        if self.determine(action1, action2):
            # player 1 wins
            self.over = True
            self.winner = 1
        elif self.determine(action2, action1):
            self.over = True
            self.winner = 2
        if action1 == 1:
            self.e1 += 1
        elif action1 in [3, 4, 6]:
            self.e1 -= 1
        elif action1 == 7:
            self.e1 -= 3
        elif action1 == 8:
            self.e1 -= 2
        if action2 == 1:
            self.e2 += 1
        elif action2 in [3, 4, 6]:
            self.e2 -= 1
        elif action2 == 7:
            self.e2 -= 3
        elif action2 == 8:
            self.e2 -= 2
        # special cases for buddha and steal:
        # if one side plays buddha and survives, the opponent's energy is cleared (the same holds if both play buddha)
        if action1 == 5:
            self.e2 = 0
        if action2 == 5:
            self.e1 = 0
        # if both players play steal at the same time, they cancel out and nothing happens
        if action1==6 and action2==6:
            return
        if action1==6 and action2!=5:
            self.e1+=self.e2
            self.e2 = 0
        if action2==6 and action1!=5:
            self.e2+=self.e1
            self.e1 = 0
        self.record.append((action1, action2))  # store the numeric moves; use pretty_record() for the text version
    def determine(self, act1, act2):
        # check (one direction at a time) whether the player using act1 kills the player using act2
        if act1 == 2 and act2 == 5: return True # buddha dies against shield
        if act1 == 3 and (act2 in [1, 5, 6, 8]): return True # shoot kills zan, buddha, steal and big-shield
        if act1 == 4 and act2 == 3: return True # reflect bounces the shot back
        if act1 == 5 and act2 == 4: return True # buddha punishes reflect
        if act1 == 7 and act2 <= 6: return True # chick beats every move from 1 to 6
        if act1 == 8 and act2 == 5: return True # big-shield kills buddha
        return False
    def pretty_record(self):
        return list((self.text[a], self.text[b]) for (a,b) in self.record)

Here's a quick demonstration of how to use the class defined above:

In [3]:
game = Game()
buddha, shield = game.text.index('buddha'), game.text.index('shield')
game.act(buddha, shield)
print(game.over, game.winner)
True 2

Basically, after creating a Game instance, you call the act method with the moves each player has chosen. After act is called, the over and winner properties are updated to show whether the game has ended and, if so, who won. But beware that this class doesn't check for illegal moves, such as playing a move without enough energy.
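For instance, here's a quick illustration of that caveat (a hypothetical opening move, just for demonstration): player 1 shoots on the very first move without having saved any energy, and the class happily accepts it.

game = Game()
# player 1 shoots without any saved energy -- an illegal move in the real game
game.act(game.text.index('shoot'), game.text.index('zan'))
print(game.over, game.winner, game.e1)  # the shot still 'wins', and e1 goes negative

To address this, we define another class that inherits from Game, giving us a basic program that generates a random legal move for player 2.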

In [4]:
class GameAI(Game):
    def __init__(self):
        super().__init__()
        # the machine plays as player 2
    def act(self, action1):
        actions_available = [i for i in range(1, 9) if self.requires[i]<=self.e2]
        act = random.choice(actions_available)   
        super().act(action1, act)
        return act

Now let's test it to see which move it produces:

In [5]:
game = GameAI()
game.text[game.act(game.text.index('zan'))]
Out[5]:
'shield'
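We can also let it play out a whole game, picking random legal moves for player 1 by hand (a quick sketch built on the classes above):

game = GameAI()
while not game.over:
    # mirror what GameAI does for player 2: only pick moves player 1 has the energy for
    legal = [i for i in range(1, 9) if game.requires[i] <= game.e1]
    game.act(random.choice(legal))
print(game.winner, game.pretty_record())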

So the supposed AI is working properly, sticking to the game's rules. It's time to generate some game data.

Generating game data¶

You might have noticed that GameAI only generates a possible action for player 2. To fix that, we need an algorithm that generates possible actions for both players. So I defined a new class, DataGen, for this purpose.

In [6]:
class DataGen(Game):
    def __init__(self):
        super().__init__()
        self.text_record = []
        
    def act(self):
        actions_available_1 = [i for i in range(1, 9) if self.requires[i]<=self.e1]
        actions_available_2 = [i for i in range(1, 9) if self.requires[i]<=self.e2]
        act_combination = random.choice(actions_available_1), random.choice(actions_available_2)
        super().act(*act_combination)
        text = self.text[act_combination[0]], self.text[act_combination[1]]
        self.text_record.append(text)
        return text

Test it:

In [7]:
game = DataGen()
while not game.over:
    game.act()
print(game.text_record)
[('shield', 'buddha')]

Now let's generate 100,000 game records:

In [8]:
def get_one_rec():
    game = DataGen()
    while not game.over:
        game.act()
    return game.text_record

data = []
for i in range(100000):
    data.append(get_one_rec())
In [9]:
data[:5]
Out[9]:
[[('zan', 'buddha'),
  ('zan', 'buddha'),
  ('buddha', 'zan'),
  ('shield', 'shield'),
  ('buddha', 'buddha'),
  ('buddha', 'zan'),
  ('buddha', 'shield')],
 [('buddha', 'buddha'), ('shield', 'shield'), ('buddha', 'shield')],
 [('shield', 'buddha')],
 [('shield', 'buddha')],
 [('buddha', 'buddha'), ('buddha', 'shield')]]

It takes very little time to generate all this data! We now need to concatenate each record into a document that we can feed into our language model:

In [10]:
def lst_to_doc(lst):
    stream = ''
    for each in lst:
        stream += ' to '.join(each)+' , '
    stream += 'end . '
    return stream
In [11]:
lst_to_doc([('buddha', 'zan'), ('buddha', 'zan'), ('zan', 'zan'), ('buddha', 'shoot')])
Out[11]:
'buddha to zan , buddha to zan , zan to zan , buddha to shoot , end . '

We can then map this function over all the items in our 100,000-item list:

In [12]:
documents = []
for each in data:
    documents.append(lst_to_doc(each))
In [13]:
documents[:3]
Out[13]:
['zan to buddha , zan to buddha , buddha to zan , shield to shield , buddha to buddha , buddha to zan , buddha to shield , end . ',
 'buddha to buddha , shield to shield , buddha to shield , end . ',
 'shield to buddha , end . ']

Splitting the dataset and joining the documents together:

In [14]:
train_document = ''.join(documents[:80000])
valid_document = ''.join(documents[80000:])
In [15]:
train_document[900: 1000]
Out[15]:
'eld to shield , zan to buddha , buddha to buddha , zan to buddha , shield to shield , buddha to shie'

Seems right!

Tokenization and Numericalization¶

First we import the libraries:

In [16]:
from fastbook import *
from fastai.text.all import *

Then we split the document into tokens and turn them into corresponding indices:

In [17]:
tokens = train_document.split(' ')
tokens[:10]
Out[17]:
['zan', 'to', 'buddha', ',', 'zan', 'to', 'buddha', ',', 'buddha', 'to']
In [18]:
vocab = L(*tokens).unique()
vocab
Out[18]:
(#13) ['zan','to','buddha',',','shield','end','.','shoot','steal','reflect','chick','big-shield','']
In [19]:
word2idx = {w:i for i,w in enumerate(vocab)}
nums = L(word2idx[i] for i in tokens)
nums
Out[19]:
(#1435321) [0,1,2,3,0,1,2,3,2,1,0,3,4,1,4,3,2,1,2,3...]

Now do the same thing with the validation document:

In [20]:
tokens_v = valid_document.split(' ')
nums_v = L(word2idx[i] for i in tokens_v)

Creating the Dataloaders¶

We want the model to predict the next token based on the previous six tokens. I know that's clearly not enough context, but for simplicity of code we'll have to temporarily endure this inconvenience.

In [105]:
seqs = L((tensor(nums[i:i+6]), nums[i+6]) for i in range(0,len(nums)-7,6))
seqs
Out[105]:
(#239219) [(tensor([0, 1, 2, 3, 0, 1]), 2),(tensor([2, 3, 2, 1, 0, 3]), 4),(tensor([4, 1, 4, 3, 2, 1]), 2),(tensor([2, 3, 2, 1, 0, 3]), 2),(tensor([2, 1, 4, 3, 5, 6]), 2),(tensor([2, 1, 2, 3, 4, 1]), 4),(tensor([4, 3, 2, 1, 4, 3]), 5),(tensor([5, 6, 4, 1, 2, 3]), 5),(tensor([5, 6, 4, 1, 2, 3]), 5),(tensor([5, 6, 2, 1, 2, 3]), 2),(tensor([2, 1, 4, 3, 5, 6]), 0),(tensor([0, 1, 0, 3, 2, 1]), 7),(tensor([7, 3, 5, 6, 2, 1]), 2),(tensor([2, 3, 0, 1, 0, 3]), 8),(tensor([8, 1, 8, 3, 2, 1]), 4),(tensor([4, 3, 5, 6, 2, 1]), 0),(tensor([0, 3, 2, 1, 2, 3]), 0),(tensor([0, 1, 2, 3, 0, 1]), 4),(tensor([4, 3, 0, 1, 0, 3]), 2),(tensor([2, 1, 7, 3, 5, 6]), 0)...]

The same with the validation set:

In [106]:
seqs_v = L((tensor(nums_v[i:i+6]), nums_v[i+6]) for i in range(0,len(nums_v)-7,6))
In [84]:
bs = 128
dls = DataLoaders.from_dsets(seqs, seqs_v, bs=128, shuffle=False)

A simple recurrent neural network¶

In [24]:
class LMModel1(Module):
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)  
        self.h_h = nn.Linear(n_hidden, n_hidden)     
        self.h_o = nn.Linear(n_hidden,vocab_sz)
        
    def forward(self, x):
        h = F.relu(self.h_h(self.i_h(x[:,0])))
        h = h + self.i_h(x[:,1])
        h = F.relu(self.h_h(h))
        h = h + self.i_h(x[:,2])
        h = F.relu(self.h_h(h))
        return self.h_o(h)

Let me explain what the code above means. i_h stands for input-to-hidden: it's an embedding layer, which is basically an optimized way of indexing into a matrix of vectors. h_h stands for hidden-to-hidden: a linear layer whose output we pass through a ReLU for non-linearity. h_o is the final layer, mapping the hidden activations to the output scores. In PyTorch, forward takes a batch of input data passed in as x, so x has shape (batch_size, 6), one row per six-token sequence. This first, simpler model only looks at the first three tokens of each sequence.

In a recurrent neural network, there is a subtle difference in how we produce the activations: the tokens are processed one at a time, in order, reusing the same layers at every step. As we can see from the code, the first word of each sequence is processed first. Let's break the process down step by step:

  • x[:, 0] selects the first word of every sequence in the batch, giving a tensor of shape (batch_size,).
  • self.i_h(x[:,0]) uses the embedding matrix to map each numericalized token to its corresponding vector, so the shape becomes (batch_size, n_hidden). n_hidden plays the same role as the number of latent factors in collaborative filtering.
  • h = F.relu(self.h_h(self.i_h(x[:,0]))) applies the hidden-to-hidden linear layer (its weight matrix is (n_hidden, n_hidden), so the activations keep the shape (batch_size, n_hidden)) and adds a ReLU for non-linearity.
  • Then we add the embedding of the second and third word of each sequence, applying the hidden-to-hidden layer and ReLU after each addition.
  • Finally, self.h_o(h) is called to produce the output scores.

Overall, we use the same weight matrices for all three words, but the activations at each step are influenced by the words that came before it. This way the model can take the order of the words into account, rather than simply adding the word-embedding vectors together.
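To make the shapes concrete, here's a small sanity check (a sketch using the names defined above; the commented shapes are what we expect, not actual notebook output):

m = LMModel1(len(vocab), 64)
xb = torch.stack([seqs[i][0] for i in range(4)])  # a tiny batch of 4 sequences
print(xb.shape)     # (4, 6): six tokens per sequence
print(m(xb).shape)  # (4, len(vocab)): one score per token in the vocabulary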

For the next piece of code: because the final activations are just scores for each token in the vocabulary, this can be treated the same way as image classification, so F.cross_entropy is used as the loss.

In [31]:
learn = Learner(dls, LMModel1(len(vocab), 64), loss_func=F.cross_entropy, 
                metrics=accuracy)
learn.fit_one_cycle(4, 1e-3)
epoch train_loss valid_loss accuracy time
0 1.440551 1.438455 0.297314 00:04
1 1.426698 1.425912 0.301389 00:04
2 1.420663 1.420298 0.302786 00:04
3 1.418132 1.419128 0.302586 00:04
In [27]:
n,counts = 0,torch.zeros(len(vocab))
for x,y in dls.valid:
    n += y.shape[0]
    for i in range_of(vocab): counts[i] += (y==i).long().sum()
idx = torch.argmax(counts)
idx, vocab[idx.item()], counts[idx].item()/n
Out[27]:
(tensor(2), 'buddha', 0.2601081081081081)

As we can see, the most common target token is 'buddha', accounting for about 26% of the validation targets. A model that always predicted 'buddha' would therefore get about 26% accuracy, so our 30% result is better, though not by a lot.

The logic of this piece of code deserves a brief explanation. We iterate through dls.valid, which gives us one batch at a time. y is a vector of 128 targets, corresponding to our batch_size, so the loop is simply counting how often each vocabulary token appears as a target.

Using a loop¶

Now, instead of adding the activations and applying ReLU by hand, we use a loop to automate the process. Note that this loop also runs over all six input tokens, not just the first three.

In [30]:
class LMModel2(Module):
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)  
        self.h_h = nn.Linear(n_hidden, n_hidden)     
        self.h_o = nn.Linear(n_hidden,vocab_sz)
        
    def forward(self, x):
        h = 0
        for i in range(6):
            h = h + self.i_h(x[:,i])
            h = F.relu(self.h_h(h))
        return self.h_o(h)
learn = Learner(dls, LMModel2(len(vocab), 64), loss_func=F.cross_entropy, 
                metrics=accuracy)
learn.fit_one_cycle(4, 1e-3)
epoch train_loss valid_loss accuracy time
0 1.212522 1.203242 0.371975 00:05
1 1.195914 1.188813 0.369530 00:05
2 1.189458 1.183382 0.368815 00:05
3 1.186409 1.182275 0.372624 00:05

Adding detach¶

A problem with our current model is that the hidden state h gets reset to zero on every call to forward. This way, the neural network only ever sees short, disconnected fragments of the real document, which is not good for training. So we now store h as an attribute of the model, so it can remember what it has seen in the past.

However, this creates another problem. The method we just decided to use makes our model really deep, and for each of these layers (there might be 10,000 of them) we would need to backpropagate all the way back to the very start. Memory would explode! So we use PyTorch's detach method so that only the last six layers' derivatives are calculated.

In [32]:
class LMModel3(Module):
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)  
        self.h_h = nn.Linear(n_hidden, n_hidden)     
        self.h_o = nn.Linear(n_hidden,vocab_sz)
        self.h = 0
        
    def forward(self, x):
        for i in range(6):
            self.h = self.h + self.i_h(x[:,i])
            self.h = F.relu(self.h_h(self.h))
        out = self.h_o(self.h)
        self.h = self.h.detach()
        return out
    
    def reset(self): self.h = 0

It is not an easy task to fully understand what detach is doing here. I will try my best to explain what's actually going on:

  • We do the regular forward propagation in the forward method, then we call self.h = self.h.detach().
  • Remember the basic rules of deep learning: PyTorch automatically tracks the forward pass and the calculation of the loss, building a computation graph so it can get the derivative of each parameter. If we never cut that graph while the hidden state h accumulates millions of calculations, the graph would grow extremely long and eventually exhaust memory.
  • So the graph should be cut into several small parts. For example, if we have 1000 words in total to process, then without detach() the last word would need gradients flowing back through 1000 layers. But what if, after each time forward is called, we 'cut' the graph, so that future backward passes stop there? That's exactly what detach() is doing!
  • When detach() is called, previously computed gradients are not affected. It only influences the backpropagation triggered by future forward passes. So each time we call backward(), the gradients are only calculated along the last six layers, keeping training efficient (see the toy example right after this list).
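Here's a tiny toy example (plain tensors, unrelated to our model) showing how detach() stops gradients from flowing through earlier computations:

w = torch.tensor([2.0], requires_grad=True)
h = w * 3            # this computation is recorded in the graph
h = h.detach()       # h is now treated as a constant with no history
loss = (h * w).sum()
loss.backward()
print(w.grad)        # tensor([6.]) -- only the use of w after the detach contributes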
In [107]:
dls = DataLoaders.from_dsets(
    seqs, seqs_v, 
    bs=128, drop_last=True, shuffle=False)
In [108]:
learn = Learner(dls, LMModel3(len(vocab), 64), loss_func=F.cross_entropy,
                metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(4, 0.01)
epoch train_loss valid_loss accuracy time
0 1.216328 1.203639 0.371385 00:06
1 1.202349 1.197370 0.371319 00:06
2 1.191182 1.185475 0.371069 00:06
3 1.187282 1.183284 0.372052 00:06

Note that there is no significant performance boost. Maybe only giving six tokens for our model to predict on is not enough.

Creating more signal¶

We now give the model more to learn from: instead of a single target per sequence, it will predict the next token after every input token (we could also make the sequences longer than six tokens, but we'll keep sl = 6 here). But first we need to introduce another method of data preparation: now that detach() lets the hidden state carry over from batch to batch, the natural order of the sequences matters! And a big problem occurs with our current dataset: the first 128 items go into a single batch, but they are consecutive chunks of the same part of the document, so the network processes them in parallel and the continuity between batches is lost.

So we need a way to 'transpose' our sequences, so that each position in the batch reads its own contiguous stream of the document from one batch to the next. That's where group_chunks comes in.

In [92]:
def group_chunks(ds, bs):
    m = len(ds) // bs
    new_ds = L()
    for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))
    return new_ds
In [93]:
group_chunks(list(range(1, 1001)), 100)
Out[93]:
(#1000) [1,11,21,31,41,51,61,71,81,91,101,111,121,131,141,151,161,171,181,191...]

I'm using a simpler version of the input data to demonstrate. Suppose we have a sequence [1, 2, ..., 1000] with a natural order, which needs to be split into ten batches (meaning batch_size = 100). We don't want to split it into [1, 2, ..., 100], [101, 102, ..., 200] and so forth, because then each position in a batch would jump by 100 items from one batch to the next, losing the continuity our carried-over hidden state relies on. When we rearrange the list using group_chunks, it becomes [1, 11, 21, ..., 991, 2, 12, ...]: the first batch is [1, 11, 21, ...], the second is [2, 12, 22, ...], and each position in the batch sees a contiguous stream of items across successive batches, so the natural order is maintained.
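A quick peek at how the rearranged toy data would be batched (assuming batches of 100 taken in order, as shuffle=False does):

toy = group_chunks(list(range(1, 1001)), 100)
print(toy[:3])       # 1, 11, 21  -- the start of the first batch
print(toy[100:103])  # 2, 12, 22  -- the second batch continues each position by one step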

With the previous models, the group_chunks method wouldn't have added much of a performance boost. But now it is essential.

We also no longer need to pass drop_last=True, because group_chunks already truncates each dataset to a multiple of the batch size.

In [113]:
sl = 6
seqs_mult = L((tensor(nums[i:i+sl]), tensor(nums[i+1:i+sl+1]))
         for i in range(0,len(nums)-sl-1,sl))
seqs_mult_v = L((tensor(nums_v[i:i+sl]), tensor(nums_v[i+1:i+sl+1]))
         for i in range(0,len(nums_v)-sl-1,sl))
dls = DataLoaders.from_dsets(group_chunks(seqs_mult, bs),
                             group_chunks(seqs_mult_v, bs),
                             bs=bs, shuffle=False)

Also worth mentioning is that we did something different to the sequences. Not only do we want the model to predict the word that follows each six-token sequence; we want it to predict the next token after every single input token. This requires a change in our dataset and model, but the result is worth it: by giving more signal to our model, it has more things to learn from, and therefore gets better performance.

In [129]:
class LMModel4(Module):
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)  
        self.h_h = nn.Linear(n_hidden, n_hidden)     
        self.h_o = nn.Linear(n_hidden,vocab_sz)
        self.h = 0
        
    def forward(self, x):
        outs = []
        for i in range(sl):
            self.h = self.h + self.i_h(x[:,i])
            self.h = F.relu(self.h_h(self.h))
            outs.append(self.h_o(self.h))
        self.h = self.h.detach()
        return torch.stack(outs, dim=1)
    
    def reset(self): self.h = 0

There's a slight change to the forward function too. At each step of the loop we use the current hidden state to make a prediction, collecting all of them in a list of length sl. After being stacked on dim=1, the list outs becomes a tensor of shape (bs, sl, vocab_sz), and doing cross entropy on this higher-dimensional tensor needs a modified version of the loss function: basically we flatten the tensor first, and then call the default F.cross_entropy.

In [132]:
def loss_func(inp, targ):
    return F.cross_entropy(inp.view(-1, len(vocab)), targ.view(-1))
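As a quick sanity check of these shapes (a sketch using the objects defined above, not actual notebook output):

xb, yb = dls.one_batch()
preds = LMModel4(len(vocab), 64)(xb)
print(preds.shape, yb.shape)  # expect (bs, sl, vocab_sz) and (bs, sl)
print(loss_func(preds, yb))   # both tensors are flattened before F.cross_entropy is applied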
In [133]:
learn = Learner(dls, LMModel4(len(vocab), 64), loss_func=loss_func,
                metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(5, 3e-3)
epoch train_loss valid_loss accuracy time
0 0.585487 0.591583 0.683835 00:07
1 0.572355 0.574745 0.685032 00:07
2 0.569663 0.570365 0.684243 00:07
3 0.567611 0.569115 0.684288 00:07
4 0.567054 0.568130 0.684638 00:07

This is a surprising accuracy!

Trying a multilayered RNN¶

What we've actually changed here is switching to PyTorch's built-in nn.RNN module, which does basically the same thing as our previous hand-written loop, but lets us easily increase the depth of our RNN by stacking layers.

In [136]:
class LMModel5(Module):
    def __init__(self, vocab_sz, n_hidden, n_layers):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = nn.RNN(n_hidden, n_hidden, n_layers, batch_first=True)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h = torch.zeros(n_layers, bs, n_hidden)
        
    def forward(self, x):
        res,h = self.rnn(self.i_h(x), self.h)
        self.h = h.detach()
        return self.h_o(res)
    
    def reset(self): self.h.zero_()
learn = Learner(dls, LMModel5(len(vocab), 64, 3), 
                loss_func=CrossEntropyLossFlat(), 
                metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(5, 3e-3)
epoch train_loss valid_loss accuracy time
0 0.578014 0.584330 0.684266 00:12
1 0.574130 0.577295 0.685140 00:11
2 0.569712 0.574084 0.685457 00:12
3 0.568163 0.571038 0.685151 00:11
4 0.567217 0.569991 0.684593 00:12

Sadly, this doesn't make much of a difference for our model.

Using LSTMs¶

It's time to introduce a powerful tool: long short-term memory (LSTM)!

[Figure: the architecture of an LSTM cell (1750573234874.png)]

First, the arrows for input and old hidden state are joined together. In the RNN we wrote earlier in this chapter, we were adding them together. In the LSTM, we stack them in one big tensor. This means the dimension of our embeddings (which is the dimension of $x_{t}$) can be different than the dimension of our hidden state. If we call those n_in and n_hid, the arrow at the bottom is of size n_in + n_hid; thus all the neural nets (orange boxes) are linear layers with n_in + n_hid inputs and n_hid outputs.

The first gate (looking from left to right) is called the forget gate. Since it’s a linear layer followed by a sigmoid, its output will consist of scalars between 0 and 1. We multiply this result by the cell state to determine which information to keep and which to throw away: values closer to 0 are discarded and values closer to 1 are kept. This gives the LSTM the ability to forget things about its long-term state. For instance, when crossing a period or an 'end' token, we would expect it to (have learned to) reset its cell state.

The second gate is called the input gate. It works with the third gate (which doesn't really have a name but is sometimes called the cell gate) to update the cell state. For instance, we may see a new gender pronoun, in which case we'll need to replace the information about gender that the forget gate removed. Similar to the forget gate, the input gate decides which elements of the cell state to update (values close to 1) or not (values close to 0). The third gate determines what those updated values are, in the range of –1 to 1 (thanks to the tanh function). The result is then added to the cell state.

The last gate is the output gate. It determines which information from the cell state to use to generate the output. The cell state goes through a tanh before being combined with the sigmoid output from the output gate, and the result is the new hidden state.

Here's a simple implementation of the LSTM cell, with a couple of notes below:

In [143]:
class LSTMCell1(Module):
    def __init__(self, ni, nh):
        self.forget_gate = nn.Linear(ni + nh, nh)
        self.input_gate  = nn.Linear(ni + nh, nh)
        self.cell_gate   = nn.Linear(ni + nh, nh)
        self.output_gate = nn.Linear(ni + nh, nh)

    def forward(self, input, state):
        h,c = state
        h = torch.cat([h, input], dim=1)
        forget = torch.sigmoid(self.forget_gate(h))
        c = c * forget
        inp = torch.sigmoid(self.input_gate(h))
        cell = torch.tanh(self.cell_gate(h))
        c = c + inp * cell
        out = torch.sigmoid(self.output_gate(h))
        h = out * torch.tanh(c)
        return h, (h,c)
  • The hidden state and the cell state are passed in together as state, so we unpack them with h,c = state.
  • As described above, we concatenate the hidden state and the input into one big tensor with torch.cat; each gate is then a linear layer on that tensor, followed by a sigmoid (or a tanh for the cell gate).
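A quick sanity check of the cell with random tensors (hypothetical sizes ni=4 and nh=8, just to confirm the shapes):

cell = LSTMCell1(4, 8)
x = torch.randn(2, 4)                         # a batch of 2 inputs
h0, c0 = torch.zeros(2, 8), torch.zeros(2, 8)
out, (h1, c1) = cell(x, (h0, c0))
print(out.shape, h1.shape, c1.shape)          # all three are (2, 8)

Next comes a refactored cell that computes all four gates with one big matrix multiplication instead of four smaller ones, followed by LMModel6, which simply uses PyTorch's built-in nn.LSTM: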
In [137]:
class LSTMCell(Module):
    def __init__(self, ni, nh):
        self.ih = nn.Linear(ni,4*nh)
        self.hh = nn.Linear(nh,4*nh)

    def forward(self, input, state):
        h,c = state
        # One big multiplication for all the gates is better than 4 smaller ones
        gates = (self.ih(input) + self.hh(h)).chunk(4, 1)
        ingate,forgetgate,outgate = map(torch.sigmoid, gates[:3])
        cellgate = gates[3].tanh()

        c = (forgetgate*c) + (ingate*cellgate)
        h = outgate * c.tanh()
        return h, (h,c)
class LMModel6(Module):
    def __init__(self, vocab_sz, n_hidden, n_layers):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]
        
    def forward(self, x):
        res,h = self.rnn(self.i_h(x), self.h)
        self.h = [h_.detach() for h_ in h]
        return self.h_o(res)
    
    def reset(self): 
        for h in self.h: h.zero_()
In [138]:
learn = Learner(dls, LMModel6(len(vocab), 128, 2), 
                loss_func=CrossEntropyLossFlat(), 
                metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(5, 1e-2)
epoch train_loss valid_loss accuracy time
0 0.575987 0.578387 0.685143 00:20
1 0.571813 0.574853 0.684710 00:20
2 0.569074 0.570702 0.685215 00:21
3 0.567799 0.569034 0.684810 00:20
4 0.567116 0.568250 0.684554 00:20

Weight-tied Regularized LSTM¶

In [139]:
class LMModel7(Module):
    def __init__(self, vocab_sz, n_hidden, n_layers, p):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)
        self.drop = nn.Dropout(p)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h_o.weight = self.i_h.weight  # weight tying: the output layer shares the embedding weights
        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]
        
    def forward(self, x):
        raw,h = self.rnn(self.i_h(x), self.h)
        out = self.drop(raw)  # apply dropout to the LSTM outputs before the final layer
        self.h = [h_.detach() for h_ in h]
        return self.h_o(out),raw,out  # also return the raw and dropped activations for fastai's RNN regularization
    
    def reset(self): 
        for h in self.h: h.zero_()
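What's new here: nn.Dropout(p) randomly zeroes activations during training as a form of regularization, and the assignment self.h_o.weight = self.i_h.weight ties the output weights to the input embedding, so the model uses the same representation for reading tokens and for predicting them. A quick (hypothetical) check that weight tying really shares a single parameter:

m = LMModel7(len(vocab), 64, 2, 0.4)
print(m.h_o.weight is m.i_h.weight)  # True -- both layers point at the same tensor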
In [140]:
learn = TextLearner(dls, LMModel7(len(vocab), 64, 2, 0.4),
                    loss_func=CrossEntropyLossFlat(), metrics=accuracy)
In [142]:
learn.fit_one_cycle(5, 1e-2, wd=0.1)
epoch train_loss valid_loss accuracy time
0 0.577216 0.579187 0.684898 00:13
1 0.573831 0.572951 0.684896 00:13
2 0.570968 0.571168 0.685196 00:13
3 0.568685 0.569038 0.684913 00:13
4 0.567818 0.568213 0.684796 00:13