# Regex classification with Monte Carlo tree search

Last time we used our regex classifier to start classifying our ingredient quantities. Unfortunately, we were hampered by having to join rules together, as that was the only way to enforce an ‘or’ relationship. It would be nice if we could just provide the list of rules and let the classifier figure out which ones to run - and in what order.

## Deducing the rule sequence

To get from our input string to a suitably classified output, we need to apply one or more of the available rules. Even with the small number of rules we have so far, there are many subsets we could choose, and because rule operations are not generally commutative, each subset could be applied in many different orders. So how can we select a suitable sequence of rules to achieve the required result?

One approach is to use a technique called Monte Carlo tree search. The idea is simple: play a random sequence of moves, build a tree recording the success of each branch, and weight the selection of subsequent random moves based on that previous success. We repeat this many times and then finally walk the path with the most success, applying the rules along that path to produce the desired output.

## What is success?

For this approach to work, we need to define exactly what we mean by success and failure. In our case, we are looking to extract certain information from the output string, so a successful path is one that results in that information being present. We can test for this with a regex. Conversely, if we reach a certain depth without matching our regex, we can write off that path as a loss.
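As an illustration, the success and failure tests might look like the sketch below. The `<number>`/`<unit>` tag markup, the `TARGET` pattern and the `path_outcome` helper are all hypothetical; the real target pattern depends on what the classification is meant to extract.

```python
import re

# Hypothetical target: a number followed by a unit, in an assumed <tag>...</tag> markup.
TARGET = re.compile(r'<number>\d+</number>\s*<unit>[a-z]+</unit>')

def path_outcome(s, depth, max_depth):
    """Classify the state of a path as 'win', 'loss' or 'continue'."""
    if TARGET.search(s):
        return 'win'       # the information we want is present in the string
    if depth >= max_depth:
        return 'loss'      # too deep without a match: write the path off
    return 'continue'      # keep applying rules
```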

## Bad move

From any point in our move tree, our next move could be to apply any one of our rules. However, many of those moves are useless or even harmful. For example, if applying a particular rule has no effect on the output string, there is no point making that move. Similarly, if applying a rule classifies part of the string as X when it was already classified as X, the move is pointless and possibly downright unhelpful. We should declare such moves illegal, because not only are they unhelpful in completing the classification, they also eat up resources that could be better spent exploring other parts of the move tree.
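Here is a minimal sketch of such a legality check, assuming classified text is marked up as `<type>...</type>` tags. `is_legal` is a hypothetical helper; its pattern is equivalent to the one `try_play_move` uses later (without the redundant backslashes) and flags both a tag re-opened inside itself and an empty classification, in addition to the ‘no change’ test.

```python
import re

# Assumed markup: classified text is wrapped as <type>...</type>.
# Alternative 1: a tag re-opened before it is closed (<x> <y> <x>), i.e. a re-classification.
# Alternative 2: an empty classification (<x></x>).
BAD_CLASSIFICATION = re.compile(
    r'<(?P<type>[a-z_]+)>(\s*<[a-z_]+>)*\s*<(?P=type)>'
    r'|<(?P<type2>[a-z_]+)>\s*</(?P=type2)>')

def is_legal(s_before, s_after):
    """A move is legal if it changes the string without a bad classification."""
    return s_after != s_before and BAD_CLASSIFICATION.search(s_after) is None
```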

## Keep it sparse

If we build the entire tree at the outset, we’ll soak up a lot of memory. A better option is to store only the branches that we have actually followed. This works well because the algorithm favours successful routes: as the depth increases and the branch count rises exponentially, the probability of visiting large parts of the tree (those below successive unsuccessful branches) becomes very small.
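To put rough numbers on this, suppose there are $n > 1$ rules and paths may reach depth $D$. A fully built tree has one node per possible move sequence:

$$
N_{\text{full}} = \sum_{k=0}^{D} n^{k} = \frac{n^{D+1} - 1}{n - 1}
$$

With $n = 10$ rules and $D = 5$ that is already $(10^{6} - 1)/9 = 111{,}111$ nodes. In a sparse tree, each step of a path adds at most one new node per rule (one for each move tried at that step), so after $P$ paths it holds at most $1 + P D n$ nodes, which grows linearly rather than exponentially with depth.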

## The move tree

Let’s think about what a node in our move tree needs. We’ll start with an empty class:

```
class explored_move:
    pass
```

We need a collection of child nodes representing the next moves that we have explored from the current position. As our tree is sparse, the obvious container to use here is a dictionary that maps from a unique move id to the next move node - we can populate this dictionary as we explore.

```
class explored_move:
    def __init__(self):
        self.next_moves = {}
```

We can record success/failure by storing the ratio of the number of times we have won (matched the target pattern) after playing the move to the number of times we have played the move. Let’s keep the numerator and denominator as two integers so we can simply increment them. We’ll increment the denominator (number of plays) when we visit a node and the numerator (number of wins) when the target pattern is matched:

```
class explored_move:
    initial_wins = 1
    initial_plays = 2

    def __init__(self):
        self.next_moves = {}
        self.wins = explored_move.initial_wins
        self.plays = explored_move.initial_plays
```

Notice that we have initialised plays to 2 and wins to 1. Why have we done this? We’ll come to this later, but essentially we need the win ratio to be non-zero or this move will never have a chance of being selected.

The target pattern could be matched several moves further down the tree from here so we’ll also need a link back to our parent node (the previous move) so that we can propagate our wins back up the tree:

```
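class explored_move:
    initial_wins = 1
    initial_plays = 2

    def __init__(self, num_rules, prev_move=None):
        self.num_rules = num_rules    # number of rules, i.e. possible next moves
        self.prev_move = prev_move    # parent node, None for the root
        self.next_moves = {}
        self.wins = explored_move.initial_wins
        self.plays = explored_move.initial_plays
```

The previous move is provided as a constructor parameter. For the root node there is no previous move, so we’ll default the parameter to `None` so it can be omitted in this case. The constructor also takes the number of rules, `num_rules`, because the methods that choose and create next moves need to know how many moves are possible from each node.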

Let’s add some helper methods to update our ratio. First we’ll need to record when we play a move. This is simple - we just increment the play count:

```
def play(self):
    self.plays += 1
```

Next we want to record when we find an illegal move. We do this by setting the win count to zero, which guarantees that the move is never selected again:

```
def illegal(self):
    self.wins = 0
```

Finally, we want to record a win. A win is only discovered at the end of a path, and it has to be propagated back up through all the nodes in the current path to the root. We do this by calling `win` recursively until `prev_move is None`:

```
def win(self):
    # plays was already incremented by play() when the move was made,
    # so a win only adds to the win count
    self.wins += 1
    if self.prev_move is not None:
        self.prev_move.win()
```

## Next move selection

Now that we’ve got the basics in, let’s consider how to select a random move. If we were selecting 1 of n equally weighted moves, we could simply generate a random integer from 0 to n - 1, but we want to weight our moves based on previous success. Think of the unweighted case as a sequence of n intervals, where move i owns the interval from i to i + 1. Each interval has length 1, so a random integer in the range 0 to n - 1 has an equal chance of landing in any of them. If we double the length of just one interval so that it has length 2 and pick a random integer from 0 to n, that interval has twice the probability of being hit compared to the others. In general, we can multiply the length of each interval by a weighting factor and raise the upper bound of the random number accordingly, so that the probability of hitting each interval is proportional to that move’s win ratio.
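The interval idea can be seen in isolation in the sketch below. `weighted_choice` is a hypothetical helper, not part of the classifier; it assumes non-negative integer weights and finds the interval with `bisect` over the running totals from `itertools.accumulate`, rather than with a linear scan.

```python
import bisect
import itertools
import random

def weighted_choice(weights):
    """Pick index i with probability weights[i] / sum(weights), or -1 if all are zero.

    Move i owns the half-open interval [sum(weights[:i]), sum(weights[:i + 1]));
    a uniform integer in [0, total) lands in exactly one of them.
    """
    bounds = list(itertools.accumulate(weights))   # upper bound of each interval
    total = bounds[-1] if bounds else 0
    if total == 0:
        return -1                                  # no move possible
    j = random.randrange(total)                    # uniform in 0 .. total - 1
    return bisect.bisect_right(bounds, j)          # first interval whose bound exceeds j
```

A zero weight gives an empty interval, so that index can never be returned, which is exactly how illegal moves drop out.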

The win ratio is the number of wins divided by the number of plays for a given move. Earlier, we set the initial number of wins to 1 and number of plays to 2, which gives a win ratio of 0.5. This is a good starting point because it leaves room to move up or down following a win or a loss. If we started with a ratio of 1, explored wins could never out-perform unexplored branches. Conversely, if we started with a ratio of zero, no move would ever be selected, as the probability of hitting any given move would also be zero. A ratio of 0.5 is a nice middle ground between those extremes.

We could simply accumulate the floating point win ratios, but if we first multiply through by a common denominator we can avoid floating point operations altogether. The common denominator is calculated by multiplying together the play counts of all moves with non-zero win counts (a move with zero wins contributes nothing, so its play count can be skipped):

```
d = 1
for i in range(self.num_rules):
    if self.num_wins(i) > 0:
        d *= self.num_plays(i)
```

We can then calculate the upper limit for the random index into our intervals by accumulating `wins * (d // plays)` for each move, which is its win ratio scaled by the common denominator `d`:

```
max_i = 0
for m in range(self.num_rules):
    max_i += self.num_wins(m) * (d // self.num_plays(m))
```
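As a worked example of the common-denominator trick, take three moves with win ratios 1/2, 2/3 and 0/5 (the third has been marked illegal). Only the first two contribute to the denominator, so `d = 2 * 3 = 6`, and the integer weights 3, 4 and 0 are exactly 6 times the ratios. `integer_weights` is a hypothetical standalone helper that mirrors the two loops above:

```python
def integer_weights(wins, plays):
    """Scale each wins[i] / plays[i] by a common denominator so all weights are integers."""
    d = 1
    for w, p in zip(wins, plays):
        if w > 0:          # zero-win moves contribute nothing, so skip their plays
            d *= p
    return [w * (d // p) for w, p in zip(wins, plays)], d

# Ratios 1/2, 2/3 and 0/5 with common denominator 2 * 3 = 6.
weights, d = integer_weights([1, 2, 0], [2, 3, 5])
print(d, weights)   # 6 [3, 4, 0]
```

Because Python integers are arbitrary precision, `d` cannot overflow, although it grows with every rule; taking `math.lcm` (Python 3.9+) of the same play counts would give the smallest such denominator.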

Now it’s simply a case of generating a random number in the interval `0` to `max_i - 1` and accumulating intervals until we reach the interval containing the random number. The index of that interval is the index of our next move:

```
j = random.randrange(0, max_i)
i = 0
for m in range(self.num_rules):
    i += self.num_wins(m) * (d // self.num_plays(m))
    if i > j:
        return m
```

Putting it all together, our complete `choose_next_move` method is:

```
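import random

def choose_next_move(self):
    # common denominator: product of the play counts of moves that can still win
    d = 1
    for m in range(self.num_rules):
        if self.num_wins(m) > 0:
            d *= self.num_plays(m)
    # total length of all the weighted intervals
    max_i = 0
    for m in range(self.num_rules):
        max_i += self.num_wins(m) * (d // self.num_plays(m))
    # every move has zero weight: no move possible
    if max_i == 0:
        return -1
    # pick a point and find the interval that contains it
    j = random.randrange(0, max_i)
    i = 0
    for m in range(self.num_rules):
        i += self.num_wins(m) * (d // self.num_plays(m))
        if i > j:
            return m
    raise Exception('Unreachable: ' + str(i) + ' > ' + str(max_i))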
```

Notice we’ve added a check that `max_i` is non-zero. A zero `max_i` means there is a zero probability of selecting any move, so we return -1, meaning ‘no move possible’. We also raise an exception at the end as a sanity check - this should be unreachable if the preceding code is correct.

We call `num_wins()` and `num_plays()` to get the number of wins and the number of plays for a given next move. Why the next move and not the current move? Well, we’d like to be able to determine the success ratio for all the possible next moves, not just the explored ones. These helpers hide the sparse nature of the tree by returning default values for the unexplored branches:

```
def num_wins(self, i):
    return self.next_moves[i].wins if i in self.next_moves else explored_move.initial_wins

def num_plays(self, i):
    return self.next_moves[i].plays if i in self.next_moves else explored_move.initial_plays
```

## Time to play

Now that we can select a next move, we want the ability to play it. Playing a move involves finding its node if it has already been explored (or creating and linking it if it hasn’t) and then updating its play count:

```
def play_next(self, rule_idx):
    if rule_idx not in self.next_moves:
        self.next_moves[rule_idx] = explored_move(self.num_rules, self)
    next_move = self.next_moves[rule_idx]
    next_move.play()
    return next_move
```

Now that we have everything we need to play sequences of moves and build up a tree of explored moves, let’s create a function that tries to play a randomly selected legal move. It returns the node to continue playing from, along with the current string for that node.

We’ll call `choose_next_move` to select a random move, which may or may not be legal. If no move is possible, we return `None, None` to indicate that this path is a dead end. Otherwise, we make the move by applying its rule to the string. If the string has changed and the result contains neither a re-classification nor an empty classification, we return the move that was made along with the updated string. Otherwise, we mark the move as illegal and return the current move (the illegal move’s parent) along with the original string. Here’s how our complete function looks:

```
import re

def try_play_move(rules, cur_move, s):
    # try to select a move
    r = cur_move.choose_next_move()
    # if no valid move return Nones
    if r == -1:
        return None, None
    # make the move
    next_move = cur_move.play_next(r)
    s_next = re_sub(rules[r].p, rules[r].s, s)
    # legal if the string changed and there is no re-classification or empty classification
    if s_next != s and re.search(r'\<(?P<type>[a-z_]+)\>(\s*\<[a-z_]+\>)*\s*\<(?P=type)\>|\<(?P<type2>[a-z_]+)\>\s*\<\/(?P=type2)\>', s_next) is None:
        return next_move, s_next
    # otherwise, mark illegal and return current
    next_move.illegal()
    return cur_move, s
```

## Following multiple paths

Now that we can play legal moves along a path, we need to run many such paths to build up the tree.

For each path, we’ll start with the original input string and the root of the move tree and make moves up to a maximum depth. If we reach a point where there are no possible moves left then we’ll abandon that path. If the result of making a move causes the string to match the desired output pattern then we record a win for that path and move on to the next one. The algorithm looks like this:

```
for p in range(paths):
    s = sin
    cur_move = root
    for d in range(max_depth):
        next_move = cur_move
        while next_move == cur_move:
            next_move, s_next = try_play_move(rules, cur_move, s)
        if next_move is None:
            break
        cur_move = next_move
        s = s_next
        if re_match(pout, s):
            cur_move.win()
            break
```

## Finding the best route

Now that we have built our move tree, we can traverse it, picking the move with the best win ratio at each stage, and apply those moves to the input string to produce the desired output.

Firstly, we’ll add a method to the `explored_move` class to get the best move from a particular point in the tree. This simply involves comparing the win ratios of each of the explored next moves. We multiply through by the common denominator as before to avoid floating point operations:

```
def best_move(self):
    # calculate a common denominator
    d = 1
    for i in range(self.num_rules):
        if self.num_wins(i) > 0:
            d *= self.num_plays(i)
    # start with ratio zero and no valid move
    best_ratio = 0
    best_m = (-1, None)
    # for each move, if the ratio is better, store it as the new best along with the move
    for m, node in self.next_moves.items():
        ratio = node.wins * (d // node.plays)
        if ratio > best_ratio:
            best_ratio = ratio
            best_m = (m, node)
    # return the best move
    return best_m
```

We can now use this to walk down the tree, applying the best move at each step until there are no further explored moves:

```
r, node = root.best_move()
while node is not None:
    sin = re_sub(rules[r].p, rules[r].s, sin)
    r, node = node.best_move()
```

Putting this all together, our complete function for running multiple paths and producing a classified result looks like:

```
def play(rules, sin, pout, paths, max_depth=1000):
    # start with root move
    root = explored_move(len(rules))
    # explore random paths
    for p in range(paths):
        # reset state for new path
        s = sin
        cur_move = root
        # make depth number of moves
        for d in range(max_depth):
            # play a random move
            next_move = cur_move
            while next_move == cur_move:
                next_move, s_next = try_play_move(rules, cur_move, s)
            # if no legal move then stop this path
            if next_move is None:
                break
            # move on to next move
            cur_move = next_move
            s = s_next
            # if matches final state then record win and stop path
            if re_match(pout, s):
                cur_move.win()
                break
    # replay move sequence with most wins
    r, node = root.best_move()
    while node is not None:
        sin = re_sub(rules[r].p, rules[r].s, sin)
        r, node = node.best_move()
    # return resulting string
    return sin
```
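Pulling the pieces together, here is a minimal runnable sketch of the whole approach. It is not the original code: it assumes a hypothetical `Rule` namedtuple with `p` (pattern) and `s` (replacement) fields, calls `re.sub` and `re.fullmatch` directly in place of the `re_sub` and `re_match` helpers, factors the common-denominator loop into a helper called `weights` (a name introduced here), and has `win` increment only `wins`, since `play` has already counted the visit. The rules, tags and target pattern at the bottom are toy examples.

```python
import random
import re
from collections import namedtuple

Rule = namedtuple('Rule', ['p', 's'])  # hypothetical: p is the pattern, s the replacement
# Re-classification (<x> <x>) or empty classification (<x></x>) makes a move illegal.
ILLEGAL = re.compile(r'<(?P<type>[a-z_]+)>(\s*<[a-z_]+>)*\s*<(?P=type)>'
                     r'|<(?P<type2>[a-z_]+)>\s*</(?P=type2)>')

class explored_move:
    initial_wins, initial_plays = 1, 2

    def __init__(self, num_rules, prev_move=None):
        self.num_rules, self.prev_move, self.next_moves = num_rules, prev_move, {}
        self.wins, self.plays = explored_move.initial_wins, explored_move.initial_plays

    def play(self):
        self.plays += 1
    def illegal(self):
        self.wins = 0

    def win(self):
        self.wins += 1  # the visit itself was already counted by play()
        if self.prev_move is not None:
            self.prev_move.win()

    def num_wins(self, i):
        return self.next_moves[i].wins if i in self.next_moves else explored_move.initial_wins
    def num_plays(self, i):
        return self.next_moves[i].plays if i in self.next_moves else explored_move.initial_plays

    def weights(self):
        # wins / plays for each next move, scaled to integers by a common denominator d
        d = 1
        for m in range(self.num_rules):
            if self.num_wins(m) > 0:
                d *= self.num_plays(m)
        return [self.num_wins(m) * (d // self.num_plays(m)) for m in range(self.num_rules)]

    def choose_next_move(self):
        w = self.weights()
        if sum(w) == 0:
            return -1  # no move possible
        j = random.randrange(sum(w))
        for m in range(self.num_rules):
            j -= w[m]
            if j < 0:
                return m

    def play_next(self, rule_idx):
        if rule_idx not in self.next_moves:
            self.next_moves[rule_idx] = explored_move(self.num_rules, self)
        self.next_moves[rule_idx].play()
        return self.next_moves[rule_idx]

    def best_move(self):
        w = self.weights()  # illegal moves have weight 0 and are never chosen
        best_m, best_w = (-1, None), 0
        for m, node in self.next_moves.items():
            if w[m] > best_w:
                best_m, best_w = (m, node), w[m]
        return best_m

def try_play_move(rules, cur_move, s):
    r = cur_move.choose_next_move()
    if r == -1:
        return None, None
    next_move = cur_move.play_next(r)
    s_next = re.sub(rules[r].p, rules[r].s, s)
    if s_next != s and ILLEGAL.search(s_next) is None:
        return next_move, s_next
    next_move.illegal()
    return cur_move, s

def play(rules, sin, pout, paths, max_depth=1000):
    root = explored_move(len(rules))
    for _ in range(paths):
        s, cur_move = sin, root
        for _ in range(max_depth):
            next_move = cur_move
            while next_move is cur_move:  # retry until a legal move or a dead end
                next_move, s_next = try_play_move(rules, cur_move, s)
            if next_move is None:
                break
            cur_move, s = next_move, s_next
            if re.fullmatch(pout, s):
                cur_move.win()
                break
    r, node = root.best_move()
    while node is not None:
        sin = re.sub(rules[r].p, rules[r].s, sin)
        r, node = node.best_move()
    return sin

rules = [Rule(r'\b(\d+)\b', r'<number>\1</number>'),
         Rule(r'\b(cups?|g|ml)\b', r'<unit>\1</unit>')]
pout = r'<number>\d+</number> <unit>[a-z]+</unit> .*'
```

With these toy rules, `play(rules, '2 cups flour', pout, paths=50)` returns `'<number>2</number> <unit>cups</unit> flour'`. Every path wins after two legal moves (applying the same rule twice is a re-classification and is marked illegal), so whichever order `best_move` prefers, the replayed sequence tags both the number and the unit.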

## Next time

We’ll look at using our shiny new classifier to improve the quality of our ingredient quantity classification.