Snake training: Java-Klassen mit Python

Wie man Java-Klassen in Python benutzt

Keine Kommentare

Generell sollte man zwar für jedes Problem das passende Werkzeug nutzen. Aber oftmals wird man gezwungen, den Hammer Java zu nutzen, weil der Rest des Hauses mit diesem Hammer gebaut wurde. Eine moderne Lösung dieses Problems ist natürlich die Microservice-Architektur: unabhängige Microservices, die je eine Aufgabe erledigen und in der jeweils am besten passenden Sprache geschrieben sind.

Aber was tun, wenn der Monolith bereits besteht oder das Projekt nicht groß genug ist, um die hohe Komplexität von Microservices zu rechtfertigen? Nun, für diesen Fall, in dem hohe Kopplung unvermeidbar oder sogar erwünscht ist, möchte ich hier eine Herangehensweise vorstellen. Wir werden lernen, wie wir das Machine-Learning-Ökosystem von Python verwenden, um Reinforcement Learning auf ein in Java implementiertes System anzuwenden. Das in Python trainierte Modell können wir später wieder in Java laden und benutzen. Python wird hier also nur während des Trainings verwendet und nicht im Produktiveinsatz. Der Vorteil ist, dass sich der Data Scientist freut, seine liebsten Werkzeuge verwenden zu können.

Und da es um Python geht: Welches Beispielproblem würde sich besser eignen als das klassische Spiel Snake? (Die Antwort auf diese rhetorische Frage ist vermutlich: „Eine Anspielung auf Monty Python.“ Aber mir ist kein simples Problem eingefallen, das sich um einen fliegenden Zirkus dreht.)

Der komplette Quellcode unseres Beispiels ist auf GitHub verfügbar.

Snake in Java

Unsere Ausgangssituation ist, dass wir ein Java-Programm haben, in dem die Spiellogik von Snake implementiert ist: Es ist immer ein Stück Futter auf dem Spielfeld. Wenn die Schlange Futter erreicht, wird sie länger und neues Futter erscheint. Wenn die Schlange eine der Wände oder sich selbst beißt, ist das Spiel zuende.

Unser Ziel ist es, ein neuronales Netz zu trainieren, das die Schlange so steuert, dass sie möglichst lang ist, bevor sie einen Fehler macht und das Spiel vorbei ist. Dazu brauchen wir einen Tensor, der den aktuellen Zustand des Spiels darstellt und als Input in das neuronale Netz gefüttert wird, damit es daraus den besten nächsten Schritt vorhersagt. Um dieses Beispiel simpel zu halten, ist unser Tensor nur ein Vektor mit sieben Elementen, die entweder 0 oder 1 sein können: Die ersten vier signalisieren, ob das Futter rechts, links, vor oder hinter der Schlange ist und die nächsten drei Werte signalisieren, ob das Feld links, geradeaus oder rechts von einer Wand oder einem Teil der Schlange besetzt sind.

public class SnakeLogic {
    Coordinate head; // position of the snake's head
    Coordinate food; // position of the food
    Move headDirection; // direction in which the head points
 
    public boolean[] trainingState() {
        boolean[] state = new boolean[7];
 
        // get the angle from the head to the food,
        // depending on the direction of movement `headDirection`
        double alpha = angle(head, headDirection, food);
 
        state[0] = isFoodFront(alpha);
        state[1] = isFoodLeft(alpha);
        state[2] = isFoodRight(alpha);
        state[3] = isFoodBack(alpha);
 
        // check if there is danger on these sites
        state[4] = danger(head.left(headDirection));
        state[5] = danger(head.straight(headDirection));
        state[6] = danger(head.right(headDirection));
 
        return state;
    }
 
    // omitted other fields and methods for clarity
    // find them at https://github.com/surt91/autosnake
}

Einerseits müssen wir diese Methode während des Trainings des neuronales Netzes von Python aus aufrufen können. Andererseits benötigen wir sie auch später im Produktiveinsatz in unserem Java-Programm, um dem fertig trainierten Netz eine Entscheidungsgrundlage zu liefern.

Java-Klassen in Python

Hier kommt JPype ins Spiel! Das Importieren einer Klasse aus Java — ohne dass wir die Java-Seite des Codes anfassen müssten — gelingt einfach durch:

import jpype
import jpype.imports
from jpype.types import *
 
# launch the JVM
jpype.startJVM(classpath=['../target/autosnake-1.0-SNAPSHOT.jar'])
 
# import the Java module
from me.schawe.autosnake import SnakeLogic
 
# construct an object of the `SnakeLogic` class ...
width, height = 10, 10
snake_logic = SnakeLogic(width, height)
 
# ... and call a method on it
print(snake_logic.trainingState())

JPype startet dabei eine eigene JVM im selben Prozess, der auch Python ausführt, und lässt das Python-Programm mit ihr über das Java Native Interface (JNI) kommunizieren. Das kann man sich, etwas vereinfacht, so vorstellen wie das Aufrufen von Funktionen
aus dynamischen Bibliotheken (für eingefleischte Pythonistas ist möglicherweise der Vergleich mit dem Modul ctypes hilfreich). JPype macht dies allerdings sehr komfortabel, indem es die Abbildung von Java- und Python-Klassen aufeinander transparent übernimmt.

Es sei jedoch noch erwähnt, dass es überraschend viele Projekte mit diesem Ziel und unterschiedlichen Stärken, Schwächen und Anwendungsgebieten gibt. Stellvertretend seien Jython und Py4J erwähnt:

Jython führt einen Python-Interpreter direkt in der JVM aus, sodass die gleichen Datenstrukturen effizient von Python und Java aus manipuliert werden können. Allerdings bringt das gleichzeitig Einschränkungen mit sich, was die Nutzung nativer Python-Bibliotheken angeht — da wir numpy und tensorflow nutzen wollen, scheidet diese Option also aus.

Py4J steht eher auf der anderen Seite des Spektrums. Auf der Java-Seite startet es einen Socket, über den es mit der Python-Seite kommuniziert. Der Vorteil ist, dass sich beliebig viele Python-Prozesse mit einem lang laufenden Java-Prozess verbinden können — oder umgekehrt ein Python-Prozess mit vielen JVMs, sogar über das Netzwerk. Der Nachteil ist, dass die Kommunikation über den Socket vergleichsweise langsam ist.

Das Training

Nun, da wir aus Python Zugriff auf unsere Java-Klassen haben, können wir das Deep-Learning-Framework unserer Wahl — hier Keras — nutzen, um ein Modell zu erstellen und zu trainieren. Da wir in diesem Fall eine Schlange trainieren wollen, möglichst
viele Punkte zu sammeln, werden wir einen Reinforcement-Learning-Ansatz anwenden.

Reinforcement Learning bedeutet grundsätzlich, dass wir einen Agenten mit einem Environment interagieren lassen, ihn für gute Entscheidungen belohnen und für schlechte bestrafen. Diese Disziplin sorgt häufiger für Aufsehen, beispielsweise
durch das Spielen von klassischen Atari-Spielen oder Go.

Für unseren Fall bietet es sich an, ein Trainings-Environment zu schreiben, das sich eng an den Gyms von OpenAI orientiert, da diese für Reinforcement-Learning einen Quasi-Standard darstellen.

Dafür brauchen wir zunächst eine Methode step, die eine Aktion action entgegennimmt, einen Zeitschritt simuliert und das Ergebnis der Aktion zurückgibt. Die action ist dabei der Output des neuronalen Netzes und bestimmt, ob die Schlange sich nach links oder rechts dreht oder sich weiter geradeaus bewegt. Das zurückgegebene Ergebnis besteht aus

  1. state, dem neuen Zustand (unser siebener Vektor),
  2. reward, der Bewertung der Aktion: 1 wenn die Schlange Futter gefressen hat, -1 wenn die Schlange sich selbst oder eine Wand gebissen hat und sonst 0. Und
  3. done, ob die Partie vorbei ist, also ob die Schlange sich selbst oder eine Wand gebissen hat. Sowie
  4. einem Dictionary mit Debugging-Informationen, das wir in unserem Fall einfach leer lassen.

Außerdem benötigen wir eine Methode reset, um eine neue Partie zu starten, die ebenfalls den neuen Zustand zurückgibt.

Beide Methoden können wir dank unserer existierenden Java-Klasse sehr einfach schreiben:

import jpype
import jpype.imports
from jpype.types import *
 
# Launch the JVM
jpype.startJVM(classpath=['../target/autosnake-1.0-SNAPSHOT.jar'])
 
# import the Java module
from me.schawe.autosnake import SnakeLogic
 
 
class Snake:
    def __init__(self):
        width, height = 10, 10
        # `snakeLogic` is a Java object, such that we can call
        # all its methods. This is also the reason why we
        # name it in camelCase instead of the snake_case
        # convention of Python.
        self.snakeLogic = SnakeLogic(width, height)
 
    def reset(self):
        self.snakeLogic.reset()
 
        return self.snakeLogic.trainingState()
 
    def step(self, action):
        self.snakeLogic.turnRelative(action)
        self.snakeLogic.update()
 
        state = self.snakeLogic.trainingState()
 
        done = False
        reward = 0
        if self.snakeLogic.isGameOver():
            reward = -1
            done = True
        elif self.snakeLogic.isEating():
            reward = 1
 
        return state, reward, done, {}

Diese Trainingsumgebung können wir nun mit minimalem Aufwand in das erste Beispiel aus der Keras-Dokumentation für Reinforcement Learning einbauen und das leicht angepasste Skript direkt nutzen, um mit dem Training zu beginnen:

Beginn des TrainingsNach 200 gespielten Spielen
Nach 400 gespielten SpielenNach 3000 gespielten Spielen
Spätestens seit Rocky wissen wir, dass ein Training nur mit eineer Trainings-Montage gut ist.

Die Schlange lernt tatsächlich dazu! Innerhalb weniger Minuten läuft sie zielstrebig auf das Futter zu und weicht Wänden aus — allerdings fängt sie sich gerne selbst. Für unsere Zwecke soll dieses Verhalten aber vorerst ausreichen.

Modell in Java laden

Um den Kreis zu schließen, laden wir unser trainiertes Modell mit deeplearning4j in Java …

// https://deeplearning4j.konduit.ai/deeplearning4j/how-to-guides/keras-import
public class Autopilot {
    ComputationGraph model;
 
    public Autopilot(String pathToModel) {
        try {
            model = KerasModelImport.importKerasModelAndWeights(pathToModel, false);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
 
    // infer the next move from the given state
    public int nextMove(boolean[] state) {
        INDArray input = Nd4j.create(state).reshape(1, state.length);
        INDArray output = model.output(input)[0];
 
        int action = output.ravel().argMax().getInt(0);
 
        return action;
    }
}

… wo wir die selben Methoden aufrufen, die wir während des Training genutzt haben, um die Schlange zu steuern.

public class SnakeLogic {
    Autopilot autopilot = new Autopilot("path/to/model.h5");
 
    public void update() {
        int action = autopilot.nextMove(trainingState());
        turnRelative(action);
 
        // rest of the update omitted
    }
 
    // further methods omitted
}

Fazit

Unter dem Strich ist es also überraschend einfach Java und Python gemeinsam zu nutzen, was vor allem zur Prototypen-Entwicklung sehr effizient sein kann.

Und es muss nicht direkt Deep Learning sein. Durch die sehr einfache Anwendbarkeit gibt es sicherlich auch Potential, diesen Ansatz zu wählen, um etwas explorative Datenanalyse auf der Datenbank unter Verwendung der gesamten Geschäftslogik in einem iPython Notebook zu betreiben.

Was unser Anwendungsbeispiel angeht: Dafür, dass wir keinerlei Gedanken in das Modell gesteckt haben, ist das Ergebnis überraschend gut. Für bessere Ergebnisse müsste man vermutlich das ganze Spielfeld in das neuronale Netz füttern und wir müssten uns etwas mehr Gedanken über das Modell machen. Eine kurze Google-Recherche zeigt, dass es anscheinend Modelle gibt, die ein perfektes Spiel Snake spielen können, sodass jedes einzelne Feld belegt ist. Für Snake ist es möglicherweise jedoch sinnvoller, das neuronale Netz zwischen den Ohren zu verwenden, um eine perfekte Strategie zu entwickeln. Zum Beispiel wird es immer ein perfektes Spiel, wenn die Schlange sich immer auf einem Hamilton-Pfad (ein Pfad, der alle Gitterplätze, ausgenommen die von der Schlange belegten, genau einmal besucht) zwischen Kopf und Schwanzende bewegt. Wie man effizient diese Hamilton-Pfade findet, ist dem Leser als Übung überlassen.

Über 1.000 Abonnenten sind up to date!

Die neuesten Tipps, Tricks, Tools und Technologien.
Jede Woche direkt in deine Inbox.

Kostenfrei anmelden und immer auf dem neuesten Stand bleiben!
(Keine Sorge, du kannst dich jederzeit abmelden.)

Artikel von Hendrik Schawe

How to use Java classes in Python

Weitere Inhalte zu Deep Learning

Kommentieren

Deine E-Mail-Adresse wird nicht veröffentlicht. Erforderliche Felder sind mit * markiert.