Snake training: Java classes in Python

How to use Java classes in Python


There is an old truism: “Use the right tool for the job.” Yet when building software, we are often forced to drive in screws with a hammer, simply because the rest of the application was already built with that figurative hammer, Java. One of the preferred solutions to this problem is, of course, microservices: each one handles a single task and can be written in the most suitable language.

But what do we do if the monolith already exists, or if the project is not large enough to justify the added complexity of microservices? Well, in this case, if tight coupling is unavoidable or even preferred, we can use the approach I am going to show in this blog post. We will learn how to use the machine learning ecosystem of Python to apply reinforcement learning to a system implemented in Java. After training, we can load the model into Java and use it there. Therefore, we only need Python during training, not in production. The best part of this approach: it keeps the data scientist happy, with the right tools for the job.

And since this is about Python: What would be a better toy example than the classic game Snake? (The answer to this rhetorical question is, of course: “Some reference to Monty Python.” But I really could not think of a simple problem about a flying circus.)

The complete source code of our example is available on GitHub.

Snake in Java

We start with a Java program implementing the game logic of Snake: there is always one piece of food on the field; whenever the snake reaches it, the snake grows and new food appears. If the snake bites itself or a wall, the game ends.

Our objective is to train a neural net to steer the snake such that it eats as much food as possible before it makes a mistake and the game ends. First, we need a tensor representing the current state of the game. It acts as the input of our neural net, which uses it to predict the best next step to take. To keep this example simple, our tensor is just a vector of seven elements, each either 1 or 0: the first four indicate whether the food is in front of, to the left of, to the right of, or behind the snake, and the next three signal whether the fields to the left of, in front of, and to the right of the snake’s head are blocked by a wall or the snake’s tail.

public class SnakeLogic {
    Coordinate head; // position of the snake's head
    Coordinate food; // position of the food
    Move headDirection; // direction in which the head points
 
    public boolean[] trainingState() {
        boolean[] state = new boolean[7];
 
        // get the angle from the head to the food,
        // depending on the direction of movement `headDirection`
        double alpha = angle(head, headDirection, food);
 
        state[0] = isFoodFront(alpha);
        state[1] = isFoodLeft(alpha);
        state[2] = isFoodRight(alpha);
        state[3] = isFoodBack(alpha);
 
        // check if there is danger on these sites
        state[4] = danger(head.left(headDirection));
        state[5] = danger(head.straight(headDirection));
        state[6] = danger(head.right(headDirection));
 
        return state;
    }
 
    // omitted other fields and methods for clarity
    // find them at https://github.com/surt91/autosnake
}

We will need this method on two occasions: first during training, where we call it directly from Python, and later in production, where our Java program calls it to give the trained net a basis for its decisions.

Java classes in Python

Enter JPype! The import of a Java class — without any changes to the Java sources — can be accomplished simply with the following code:

import jpype
import jpype.imports
from jpype.types import *
 
# launch the JVM
jpype.startJVM(classpath=['../target/autosnake-1.0-SNAPSHOT.jar'])
 
# import the Java module
from me.schawe.autosnake import SnakeLogic
 
# construct an object of the `SnakeLogic` class ...
width, height = 10, 10
snake_logic = SnakeLogic(width, height)
 
# ... and call a method on it
print(snake_logic.trainingState())

JPype starts a JVM in the same process as the Python interpreter and lets the two communicate via the Java Native Interface (JNI). Simplified, one can think of it as calling functions from a dynamic library (experienced Pythonistas may find a comparison to the ctypes module helpful). JPype, however, does this in a very comfortable way and automatically maps Java classes onto Python classes.
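To make the dynamic-library analogy concrete, here is what the "raw" approach looks like with the standard-library module ctypes, calling a function from the C standard library directly (this is only an illustration of the analogy, not part of the Snake project):

```python
import ctypes
import ctypes.util

# Load the C standard library; the file name differs per platform,
# so we look it up instead of hard-coding it.
libc = ctypes.CDLL(ctypes.util.find_library("c"))

# Call `abs` from libc as if it were a Python function.
print(libc.abs(-42))  # prints 42
```

With ctypes we have to manage types and loading ourselves; JPype hides all of this and presents Java classes as ordinary Python classes.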

It should also be noted that there is a surprising number of projects with this objective, each with their own strengths and weaknesses. As representatives, we will quickly look at Jython and Py4J.

Jython executes a Python interpreter directly in the JVM, such that Python and Java can share the same data structures very efficiently. But this comes with drawbacks when using native Python libraries; since we will use numpy and tensorflow, this is not an option for us.

Py4J sits at the opposite end of the spectrum. It opens a socket on the Java side, over which it communicates with Python programs. The advantage is that an arbitrary number of Python processes can connect to a long-running Java process, or the other way around: one Python process can connect to many JVMs, even over the network. The downside is the larger overhead of socket communication.

The training

Now that we can access our Java classes, we can use the deep learning framework of our choice — in our case, Keras — to create and train a model. Since we want to train a snake to collect the maximum amount of food, we choose a reinforcement learning approach.

In reinforcement learning an agent interacts with an environment and is rewarded for good decisions and punished for bad decisions. In the past, this discipline has drawn quite some media attention for playing classic Atari games or Go.

For our application, it makes sense to write a class that adheres closely to the OpenAI Gym interface, since it is a de facto standard for reinforcement learning environments.

Therefore, we need a method step, which takes an action, simulates one time step, and returns the result of that action. The action is the output of the neural net and indicates whether the snake should turn left, turn right, or keep going straight. The returned result consists of

  1. state, the new state (our vector of seven elements),
  2. reward, our valuation of the action: 1 if the snake ate food in this step, -1 if it bit itself or a wall, and 0 otherwise,
  3. done, an indicator of whether the round is finished, i.e. whether the snake bit itself or a wall, and
  4. a dictionary with debugging information, which we simply leave empty.

Further, we need a method reset to start a new round. It should also return the new state.

Both methods are easy to write thanks to our already existing Java classes:

import jpype
import jpype.imports
from jpype.types import *
 
# Launch the JVM
jpype.startJVM(classpath=['../target/autosnake-1.0-SNAPSHOT.jar'])
 
# import the Java module
from me.schawe.autosnake import SnakeLogic
 
 
class Snake:
    def __init__(self):
        width, height = 10, 10
        # `snakeLogic` is a Java object, such that we can call
        # all its methods. This is also the reason why we
        # name it in camelCase instead of the snake_case
        # convention of Python.
        self.snakeLogic = SnakeLogic(width, height)
 
    def reset(self):
        self.snakeLogic.reset()
 
        return self.snakeLogic.trainingState()
 
    def step(self, action):
        self.snakeLogic.turnRelative(action)
        self.snakeLogic.update()
 
        state = self.snakeLogic.trainingState()
 
        done = False
        reward = 0
        if self.snakeLogic.isGameOver():
            reward = -1
            done = True
        elif self.snakeLogic.isEating():
            reward = 1
 
        return state, reward, done, {}

Now, we can easily insert this training environment into the first example for reinforcement learning of the Keras documentation and directly use it to start the training:

[Animations of the snake: at the beginning of the training, after 200 played games, after 400 played games, and after 3000 played games]
Ever since Rocky, it has been clear that a training montage is necessary for success.

The snake learns! Within a few minutes, it begins to directly move towards the food and avoids the walls — but it still tends to trap itself quickly. For our purposes this should suffice for now.
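The linked Keras example implements deep Q-learning; since our state has only seven boolean entries (128 possible values), the core idea can even be sketched without a neural net, as plain tabular Q-learning. Everything below, including the toy stand-in environment, is an illustrative assumption, not the project's actual training code:

```python
import random


class ToyEnv:
    """Hypothetical 1-D stand-in with the same step/reset contract:
    the agent walks on positions 0..4, the food sits at position 4."""

    def reset(self):
        self.pos = 2
        return [True]  # single-bit state: "food is to the right"

    def step(self, action):
        self.pos += 1 if action == 1 else -1  # 1 = right, 0 = left
        if self.pos >= 4:
            return [True], 1, True, {}   # reached the food
        if self.pos <= 0:
            return [True], -1, True, {}  # hit the wall
        return [True], 0, False, {}


def state_key(state):
    # pack the boolean state vector into a table index
    return sum(int(bit) << i for i, bit in enumerate(state))


def train(env, n_actions, episodes=300, alpha=0.5, gamma=0.9, eps=0.1):
    q = {}  # maps state index -> list of action values
    for _ in range(episodes):
        s = state_key(env.reset())
        done = False
        while not done:
            row = q.setdefault(s, [0.0] * n_actions)
            if random.random() < eps:            # explore ...
                a = random.randrange(n_actions)
            else:                                # ... or act greedily
                a = row.index(max(row))
            nxt, reward, done, _ = env.step(a)
            ns = state_key(nxt)
            nrow = q.setdefault(ns, [0.0] * n_actions)
            target = reward if done else reward + gamma * max(nrow)
            row[a] += alpha * (target - row[a])  # Q-learning update
            s = ns
    return q


random.seed(0)
q = train(ToyEnv(), n_actions=2)
# after training, "right" should be valued higher than "left"
print(q[1][1] > q[1][0])  # prints True
```

The real training replaces the lookup table with a neural net that generalizes over states, but the reward-driven update loop is the same.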

Load the model in Java

Now we come full circle and load the trained model into Java using deeplearning4j:

// https://deeplearning4j.konduit.ai/deeplearning4j/how-to-guides/keras-import
public class Autopilot {
    ComputationGraph model;
 
    public Autopilot(String pathToModel) {
        try {
            model = KerasModelImport.importKerasModelAndWeights(pathToModel, false);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
 
    // infer the next move from the given state
    public int nextMove(boolean[] state) {
        INDArray input = Nd4j.create(state).reshape(1, state.length);
        INDArray output = model.output(input)[0];
 
        int action = output.ravel().argMax().getInt(0);
 
        return action;
    }
}

… where we call the same methods used during training to steer the snake.

public class SnakeLogic {
    Autopilot autopilot = new Autopilot("path/to/model.h5");
 
    public void update() {
        int action = autopilot.nextMove(trainingState());
        turnRelative(action);
 
        // rest of the update omitted
    }
 
    // further methods omitted
}

Conclusion

It is surprisingly easy to make Java and Python work hand in hand, which can be especially valuable when developing prototypes.

What’s more, it does not have to be deep learning. Since the connection between Java and Python is so easy to use, there is certainly potential to apply this approach elsewhere, for example to exploratory data analysis on a database, using the full business logic from within an IPython notebook.

Regarding our toy example: given that we did not spend a single thought on the model, the result is surprisingly good. For better results, one would probably have to use the full field as input and put more thought into the model architecture. A quick search shows that there are apparently models which can play a perfect game of Snake, such that the snake eventually occupies every single site. For Snake, though, it might be more sensible to use the neural net between one’s ears to devise a perfect strategy. For example, we can ensure a perfect game if the snake always moves on a Hamilton path between its head and the tip of its tail (i.e. a path which visits all sites except those occupied by the snake). How to find Hamilton paths efficiently will be left to the reader as an exercise.
