Training a model in Python and loading in TensorFlow.js - TensorFlow.js Tutorial p.4




Welcome to part 4 of the TensorFlow.js series, where we're going to train a model in Python and then load that trained model into a TensorFlow.js application. First, we need to train the model in Python.

To begin, we need some training data. We could have generated it in Python, but we opted to use JavaScript and ran the game until we had 100,000 training samples. You can run it yourself if you like, but I have also hosted the training data here: Pong AI training data.

If you do want to build your own training data, or generate more of it, here are the collector files:

ponggame-fixed-strike-noai-collector.js

ponggame-fixed-strike-noai-collector.html

The above scripts simply pit two "computers" against each other, running as fast as possible, to create the training data. Sometimes the game gets stuck and you need to manually change the ball speed in the console, or change the code to randomly vary the speed from time to time.

Okay, so we've got training data with 100,000 samples in JSON form. What we'd like to do is train a model in Python on this data, and then output it to something we can use in TensorFlow.js. Our JSON data looks like: {"xs":[[152,241,124,442,121,244],...], "ys":[[1,0,0],...]}, saved in training_data-100k.json. Each row of xs holds six input values, and each row of ys is a one-hot label over the three possible moves. The first thing we need to do is load this data in Python:

import json
import numpy as np


with open('training_data-100k.json') as f:
    data = json.load(f)
    xs = np.array(data['xs'])
    ys = np.array(data['ys'])
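Before training, it can be worth a quick sanity check that every sample has the shape the model expects. The sketch below assumes the layout shown above (six numbers per input row, a three-element one-hot row per label). The function name check_samples and the tiny inline sample are made up for illustration and are not part of the original scripts.

```python
import json


def check_samples(data, n_features=6, n_classes=3):
    """Return the number of samples after verifying row shapes and one-hot labels."""
    xs, ys = data["xs"], data["ys"]
    if len(xs) != len(ys):
        raise ValueError("xs and ys differ in length")
    for x, y in zip(xs, ys):
        if len(x) != n_features:
            raise ValueError("unexpected feature count: %d" % len(x))
        # A one-hot label has exactly one 1 and zeros elsewhere.
        if len(y) != n_classes or sorted(y) != [0] * (n_classes - 1) + [1]:
            raise ValueError("label is not one-hot: %r" % (y,))
    return len(xs)


# Hypothetical two-sample file contents, in the same layout as the real data.
sample = json.loads('{"xs": [[152,241,124,442,121,244], [150,240,120,440,119,243]],'
                    ' "ys": [[1,0,0], [0,1,0]]}')
print(check_samples(sample))
```

On the real file you would call check_samples(data) right after json.load(f).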

Then, as we try to find a decent model to use, we need to split this data into training and testing groups:

x_train = xs[:-10000]
y_train = ys[:-10000]
x_test = xs[-10000:]
y_test = ys[-10000:]
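The slices above hold out the last 10,000 rows as the test set. If the rows are stored in the order the games were played, the test set comes entirely from the final stretch of play; whether that matters for this data is not something we checked. As a hedged alternative, this sketch shuffles the row order with a fixed seed before splitting. shuffled_split is a hypothetical helper, and it uses plain lists rather than NumPy arrays to stay self-contained.

```python
import random


def shuffled_split(xs, ys, n_test, seed=0):
    """Shuffle rows with a fixed seed, then hold out n_test rows for testing."""
    order = list(range(len(xs)))
    random.Random(seed).shuffle(order)
    split = len(order) - n_test
    train_idx, test_idx = order[:split], order[split:]

    def pick(rows, idx):
        return [rows[i] for i in idx]

    return (pick(xs, train_idx), pick(ys, train_idx),
            pick(xs, test_idx), pick(ys, test_idx))


# Ten toy rows: hold out three for testing.
xs = [[i] * 6 for i in range(10)]
ys = [[1, 0, 0]] * 10
x_train, y_train, x_test, y_test = shuffled_split(xs, ys, n_test=3)
print(len(x_train), len(x_test))  # 7 3
```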

Now we are ready to build our Keras model. Start with the following imports:

import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout

Then below our previous code:

model = Sequential()
model.add(Dense(64, activation='relu', input_dim=6))
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(3, activation='softmax'))

Look familiar? Note that you must include input_dim for the first layer; Keras infers the input size of every later layer from the layer before it.
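To see what that shape inference amounts to, here is a small sketch that counts the trainable parameters of this network by hand. It assumes the standard fully connected layer (one weight per input-unit pair plus one bias per unit); dense_params and count_params are illustrative names, not Keras functions.

```python
def dense_params(n_inputs, n_units):
    """Weights plus biases for one fully connected layer."""
    return n_inputs * n_units + n_units


def count_params(input_dim, layer_units):
    """Chain the layers: each layer's input size is the previous layer's unit count."""
    total, n_inputs = 0, input_dim
    for n_units in layer_units:
        total += dense_params(n_inputs, n_units)
        n_inputs = n_units
    return total


# Dropout layers have no trainable parameters, so only the Dense layers count.
print(count_params(6, [64, 64, 3]))  # 448 + 4160 + 195 = 4803
```

You can compare the result against the total reported by model.summary().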

Now, just like in TensorFlow.js, we need to compile and then fit the data.

adam = keras.optimizers.Adam(lr=0.001)

model.compile(loss='categorical_crossentropy',
              optimizer=adam,
              metrics=['accuracy'])

model.fit(x_train, y_train,
          epochs=10,
          batch_size=128)
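If the loss value printed during training feels abstract, this sketch computes categorical cross-entropy for a single sample by hand. It is a plain-Python illustration, not the Keras implementation; the eps clipping is an assumption added here to avoid log(0), and the probabilities are hypothetical softmax outputs.

```python
import math


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Loss for one sample: -sum(t * log(p)) over the classes."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(y_true, y_pred))


# One-hot target "left" against two hypothetical softmax outputs.
confident = categorical_crossentropy([1, 0, 0], [0.9, 0.05, 0.05])
unsure = categorical_crossentropy([1, 0, 0], [0.4, 0.3, 0.3])
print(round(confident, 4), round(unsure, 4))
```

Because the target is one-hot, only the probability assigned to the correct class contributes, so a more confident correct prediction gives a lower loss.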

Finally, we want to save this model when we're done with it, and maybe see the results of the out-of-sample testing:

score = model.evaluate(x_test, y_test, batch_size=128)
print(score)
model.save("Keras-64x2-10epoch")
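With metrics=['accuracy'], the printed score is a list of [loss, accuracy]. An accuracy figure is easier to judge against a baseline, so this sketch computes what you would score by always predicting the most common class. majority_baseline and the toy labels are illustrative; on the real data you would pass in y_test.tolist().

```python
from collections import Counter


def majority_baseline(ys):
    """Accuracy achieved by always predicting the most frequent class."""
    # Convert each one-hot row to its class index (position of the 1).
    classes = [row.index(max(row)) for row in ys]
    most_common_count = Counter(classes).most_common(1)[0][1]
    return most_common_count / len(classes)


# Toy labels: 3 x left, 1 x no move, 1 x right.
toy_ys = [[1, 0, 0], [1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]]
print(majority_baseline(toy_ys))  # 0.6
```

If the model's test accuracy is not clearly above this number, it has not learned much beyond the class frequencies.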

Now that we have this Keras model, we'd like to convert it to be used within our actual pong application.

To start, we need to install the tensorflowjs package for Python:

pip install tensorflowjs

Next, we can make the following new import in our training script:

import tensorflowjs as tfjs

Then, alongside model.save (or instead of it), you can do:

tfjs.converters.save_keras_model(model, "tfjsmodel")

You may still want to save your keras model too, just in case you want to return to it later, rather than re-training it.
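The converter writes a directory rather than a single file. As far as I understand the TensorFlow.js layers format, that directory holds a model.json plus one or more binary weight files, and model.json names those files under a weightsManifest key. Treat that layout as an assumption and check it against your own output. Under that assumption, this sketch reports any weight file that model.json refers to but that is missing from the directory; both function names are made up for illustration.

```python
import json
import os


def list_weight_files(model_dir):
    """Return the weight file names that model.json refers to."""
    with open(os.path.join(model_dir, "model.json")) as f:
        spec = json.load(f)
    paths = []
    # Assumed layout: "weightsManifest" is a list of groups, each with "paths".
    for group in spec.get("weightsManifest", []):
        paths.extend(group.get("paths", []))
    return paths


def missing_files(model_dir):
    """Weight files named in model.json that are not present in the directory."""
    return [p for p in list_weight_files(model_dir)
            if not os.path.exists(os.path.join(model_dir, p))]
```

Running missing_files("tfjsmodel") before uploading is a cheap way to confirm the whole directory is intact, since all of its files need to be hosted together.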

Full code:

import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout
import json
import numpy as np
import tensorflowjs as tfjs

with open('training_data-100k.json') as f:
    data = json.load(f)
    xs = np.array(data['xs'])
    ys = np.array(data['ys'])


x_train = xs[:-10000]
y_train = ys[:-10000]
x_test = xs[-10000:]
y_test = ys[-10000:]

model = Sequential()
model.add(Dense(64, activation='relu', input_dim=6))
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(3, activation='softmax'))

adam = keras.optimizers.Adam(lr=0.001)

model.compile(loss='categorical_crossentropy',
              optimizer=adam,
              metrics=['accuracy'])

model.fit(x_train, y_train,
          epochs=10,
          batch_size=128)

score = model.evaluate(x_test, y_test, batch_size=128)
print(score)
model.save("Keras-64x2-10epoch")
tfjs.converters.save_keras_model(model, "tfjsmodel")

Now you should have everything you need in this local tfjsmodel directory. Next, we just need to import that model!

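Finally, to load this model in TensorFlow.js, we just need to use the following in our JavaScript (tf.loadModel is the call in the 0.11.x release used here; from TensorFlow.js 1.0 on, it is named tf.loadLayersModel):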

model = await tf.loadModel('https://path/to/model');

For example:

model =  await tf.loadModel('https://news.r6siege.cn/static/downloads/machine-learning-data/tfjsversion/model.json');

The above only works if your TensorFlow.js code is also running on news.r6siege.cn. Otherwise, you will get a CORS (Cross-Origin Resource Sharing) error. Thus, if you want to do this, you need to either host the model files on the same server you're serving the page from (loadModel uses an http/https request), or have the server hosting them enable CORS. I didn't want to do that with news.r6siege.cn for security reasons, but I did it for HKinsley.com, so you can instead load the model from https://hkinsley.com/static/tfjsmodel/model.json, doing:

model =  await tf.loadModel('https://hkinsley.com/static/tfjsmodel/model.json');
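If you would rather serve the model files yourself while developing, one option is a small static file server that adds the CORS header. This is a minimal sketch using only Python's standard library. The class and function names are made up for illustration, and the wildcard origin is an assumption that suits local testing only, not a public server.

```python
import functools
from http.server import HTTPServer, SimpleHTTPRequestHandler


class CORSRequestHandler(SimpleHTTPRequestHandler):
    """Static file handler that adds a permissive CORS header to every response."""

    def end_headers(self):
        # "*" allows any origin; restrict this to your own site in production.
        self.send_header("Access-Control-Allow-Origin", "*")
        super().end_headers()


def make_server(directory, port=8000):
    """Build (but do not start) a server that serves `directory` with CORS enabled."""
    handler = functools.partial(CORSRequestHandler, directory=directory)
    return HTTPServer(("127.0.0.1", port), handler)


# To serve the converted model: make_server("tfjsmodel").serve_forever()
```

With that running, the page could load the model from http://127.0.0.1:8000/model.json regardless of which origin the page itself was served from.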

That should work for you just fine. Here's the full code you'd need to run the Pong AI locally, or you could even host it if you wanted:

main.html

<h4>TensorFlow.js implementation of a pong-playing AI</h4>
<p>You play as the bottom paddle, use arrow keys to move.</p>
<h6>Rules: Don't read/judge the js.</h6>
<div id='mainContent'></div>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.11.2"> </script>
<script src="pongai.js"></script>
<p id="playing"></p>

pongai.js

// the actual PONG javascript code is from: https://robots.thoughtbot.com/pong-clone-in-javascript
// We have only modified it to include some javascript AI!

// init function
async function init(){

    //model =  await tf.loadModel('indexeddb://my-model-1');
    model =  await tf.loadModel('https://hkinsley.com/static/tfjsmodel/model.json');
    //model =  await tf.loadModel('tfjsversion/model.json');
    console.log('model loaded from storage');
    computer.ai_plays = true;

     if(computer.ai_plays){
        document.getElementById("playing").innerHTML = "Playing: AI";
    }else{
        document.getElementById("playing").innerHTML = "Playing: Computer";
    }

    // start a game
    animate(step);
}

// set game animation speed (game clock)
var animate = window.requestAnimationFrame || window.webkitRequestAnimationFrame || window.mozRequestAnimationFrame || function (callback) {
        window.setTimeout(callback, 1000 / 60)
    };

// create canvas
var canvas = document.createElement("canvas");
var width = 400;
var height = 600;
canvas.width = width;
canvas.height = height;
var context = canvas.getContext('2d');

// create game "objects"
var player = new Player();
var computer = new Computer();
var ball = new Ball(200, 300);
var ai = new AI();

// pressed keys
var keysDown = {};

// renders board
var render = function () {
    context.fillStyle = "#000000";
    context.fillRect(0, 0, width, height);
    player.render();
    computer.render();
    ball.render();
};

// updates game state
var update = function () {

    // update player position
    player.update(ball);

    // update "computer" position
    // ai-based
    if(computer.ai_plays){
        move = ai.predict_move();
        computer.ai_update(move);
    // or rule-based if we don't have any model yet
    }else{
        computer.update(ball);
    }

    // update ball position
    ball.update(player.paddle, computer.paddle);

    // add training data from current frame to training set
    ai.save_data(player.paddle, computer.paddle, ball)
};

// main game loop
var step = function () {
    update();
    render();
    animate(step); // runs that loop again after a "tick"
};

// paddle object
function Paddle(x, y, width, height) {
    this.x = x;
    this.y = y;
    this.width = width;
    this.height = height;
    this.x_speed = 0;
    this.y_speed = 0;
}

// renders paddle on a board
Paddle.prototype.render = function () {
    context.fillStyle = "#59a6ff";
    context.fillRect(this.x, this.y, this.width, this.height);
};

// moves paddle by x and y pixels (y is always 0 now)
Paddle.prototype.move = function (x, y) {

    // update position and speed
    this.x += x;
    this.y += y;
    this.x_speed = x;
    this.y_speed = y;

    // check if not out of the board
    if (this.x < 0) {
        this.x = 0;
        this.x_speed = 0;
    } else if (this.x + this.width > 400) {
        this.x = 400 - this.width;
        this.x_speed = 0;
    }
};

// computer player object
function Computer() {
    this.paddle = new Paddle(0, 10, 50, 10);
    //this.ai_plays = false; // will be set to true whenever ai model will be ready
}

// renders computer paddle on a board
Computer.prototype.render = function () {
    this.paddle.render();
};

// updates computer paddle position - rule-based (simply follows a ball)
Computer.prototype.update = function (ball) {

    // calculate difference in pixels between paddle and ball (cap to 5 pixels - max speed of paddle)
    var x_pos = ball.x;
    var diff = -((this.paddle.x + (this.paddle.width / 2)) - x_pos);
    if (diff < 0 && diff < -4) {
        diff = -5;
    } else if (diff > 0 && diff > 4) {
        diff = 5;
    }

    // move paddle
    this.paddle.move(diff, 0);

    // check if paddle is not outside of the board
    if (this.paddle.x < 0) {
        this.paddle.x = 0;
    } else if (this.paddle.x + this.paddle.width > 400) {
        this.paddle.x = 400 - this.paddle.width;
    }
};

// updates computer paddle position - ai-based (called from update() with the predicted move)
Computer.prototype.ai_update = function (move = 0) {
    this.paddle.move(4 * move, 0);
};

// player object
function Player() {
    this.paddle = new Paddle(0, 580, 50, 10);
}

// renders player paddle
Player.prototype.render = function () {
    this.paddle.render();
};

// updates player paddle position
//Player.prototype.update = Computer.prototype.update;


Player.prototype.update = function () {
    for (var key in keysDown) {
        var value = Number(key);
        if (value == 37) {
            this.paddle.move(-4, 0);
        } else if (value == 39) {
            this.paddle.move(4, 0);
        } else {
            this.paddle.move(0, 0);
        }
    }
};

// ball object
function Ball(x, y) {
    this.x = x;
    this.y = y;
    this.x_speed = Math.random()*4+1;
    this.y_speed = Math.random()*3+2;
    this.player_strikes = false;
    this.ai_strikes = false;
}

// renders ball on a table
Ball.prototype.render = function () {
    context.beginPath();
    context.arc(this.x, this.y, 5, 0, 2 * Math.PI, false);
    context.fillStyle = "#ddff59";
    context.fill();
};

// updates ball position
Ball.prototype.update = function (paddle1, paddle2, new_turn) {

    // update speed and upper/lower point of a ball on a table
    this.x += this.x_speed;
    this.y += this.y_speed;
    var top_x = this.x - 5;
    var top_y = this.y - 5;
    var bottom_x = this.x + 5;
    var bottom_y = this.y + 5;

    // check if ball is not outside of a table
    // bounce off the side walls
    if (this.x - 5 < 0) {
        this.x = 5;
        this.x_speed = -this.x_speed;
    } else if (this.x + 5 > 400) {
        this.x = 395;
        this.x_speed = -this.x_speed;
    }

    // if ball hits upper and lower walls - reset ball (score)
    if (this.y < 0 || this.y > 600) {
        this.x_speed = Math.random()*4+1;
        this.y_speed = Math.random()*3+2;
        this.x = 200;
        this.y = 300;
        ai.new_turn();
    }

    // move ball on a table, update angle and speed, calculate new position
    this.player_strikes = false;
    this.ai_strikes = false;
    if (top_y > 300) {
        if (top_y < (paddle1.y + paddle1.height) && bottom_y > paddle1.y && top_x < (paddle1.x + paddle1.width) && bottom_x > paddle1.x) {
            this.y_speed = -3;
            this.x_speed += (paddle1.x_speed / 2);
            this.y += this.y_speed;
            this.player_strikes = true;
            console.log('player strikes');
        }
    } else {
        if (top_y < (paddle2.y + paddle2.height) && bottom_y > paddle2.y && top_x < (paddle2.x + paddle2.width) && bottom_x > paddle2.x) {
            this.y_speed = 3;
            this.x_speed += (paddle2.x_speed / 2);
            this.y += this.y_speed;
            this.ai_strikes = true;
            console.log('ai strikes');
        }
    }
};

// AI object
function AI(){
    this.previous_data = null;                  // data from previous frame
    this.training_data = [[], [], []];          // empty training dataset
    this.training_batch_data = [[], [], []];    // empty batch (dataset to be added to training data)
    this.previous_xs = null;                    // input data from previous frame
    this.turn = 0;                              // current turn number
    this.grab_data = true;                      // enables/disables data grabbing
    this.flip_table = true;                     // flips table
    this.keep_trainig_records = true;           // keep some number of training records instead of discarding them each session
    this.training_records_to_keep = 100000;     // number of training records to keep
    this.first_strike = true;                   // first strike flag (to omit data)
}

// saves data from current frame of a game
AI.prototype.save_data = function(player, computer, ball){

    // return if grabbing is disabled
    if(!this.grab_data)
        return;

    // fresh turn, just fill initial data in
    if(this.previous_data == null){
        this.previous_data = [player.x, computer.x, ball.x, ball.y];
        return;
    }

    // if ai strikes, start recording data - empty batch
    if(ball.ai_strikes){
        this.training_batch_data = [[], [], []];
        console.log('emptying batch')
    }

    // create current data object [player_x, computer_x, ball_x, ball_y]
    // and embedding index (0 - left, 1 - no move, 2 - right)
    data_xs = [player.x, computer.x, ball.x-60, ball.y];
    index = (player.x < this.previous_data[0])?0:((player.x == this.previous_data[0])?1:2);

    // save data as [...previous data, ...current data]
    // result - [old_player_x, old_computer_x, old_ball_x, old_ball_y, player_x, computer_x, ball_x, ball_y]
    this.previous_xs = [...this.previous_data, ...data_xs];
    // add data to training set depending on index value (depending if that data relates to the move to the left, no move or move to the right)
    // only player and ball position
    this.training_batch_data[index].push([this.previous_xs[0], this.previous_xs[2], this.previous_xs[3], this.previous_xs[4], this.previous_xs[6], this.previous_xs[7]]);
    // set current data as previous data for next frame
    this.previous_data = data_xs;

    // if player strikes, add batch to training data
    if(ball.player_strikes){
        if(this.first_strike){
            this.first_strike = false;
            this.training_batch_data = [[], [], []];
            console.log('emptying batch');
        }else{
            for(i = 0; i < 3; i++)
                this.training_data[i].push(...this.training_batch_data[i]);
            this.training_batch_data = [[], [], []];
            console.log('adding batch');
        }
    }
}

// runs every turn
AI.prototype.new_turn = function(){

    // clean previous data, we are starting fresh
    this.first_strike = true;
    this.training_batch_data = [[], [], []];
    this.previous_data = null;
    this.turn++;
    console.log('new turn: ' + this.turn);

    //computer.ai_plays = !computer.ai_plays;
    if(computer.ai_plays){
        document.getElementById("playing").innerHTML = "Playing: AI";
    }else{
        document.getElementById("playing").innerHTML = "Playing: Computer";
    }

    // after x turn
    /*if(this.turn > 9){

        // train a model
        this.train();

        // allow ai to play (as we have a trained model)
        //computer.ai_plays = true;

        // empty training dataset
        this.reset();
    }*/
}

// empties training data
AI.prototype.reset = function(){
    this.previous_data = null;
    if(!this.keep_trainig_records)
        this.training_data = [[], [], []];
    this.turn = 0;

    if(computer.ai_plays){
        document.getElementById("playing").innerHTML = "Playing: AI";
    }else{
        document.getElementById("playing").innerHTML = "Playing: Computer";
    }

    console.log('reset')
    console.log('emptying batch')
}

// trains a model
AI.prototype.train = function(){

    // first we have to balance the data
    console.log('balancing');
    document.getElementById("playing").innerHTML = "Training";

    // trim data and find minimum number of training records in data for all 3 embeddings
    if(this.keep_trainig_records){
        for(i = 0; i < 3; i++){
            if(this.training_data[i].length > this.training_records_to_keep)
                this.training_data[i] = this.training_data[i].slice(
                    Math.max(0, this.training_data[i].length - this.training_records_to_keep),
                    this.training_data[i].length
                );
        }
    }
    len = Math.min(this.training_data[0].length, this.training_data[1].length, this.training_data[2].length);
    console.log(this.training_data);
    if(!len){
        console.log('no data to train on');
        return;
    }

    data_xs = [];
    data_ys = [];

    // now we need to trim data so every class contains exactly the same number of training records,
    // then randomize that data
    // and create one label record for every input data record
    // finally add training data records and label records to common tables (for training)
    // model.fit() will do the final data shuffle for us
    for(i = 0; i < 3; i++){
        data_xs.push(...this.training_data[i].slice(0, len)
            .sort(()=>Math.random()-0.5).sort(()=>Math.random()-0.5));      // trims training data to 'len' length and shuffles it
        data_ys.push(...Array(len).fill([i==0?1:0, i==1?1:0, i==2?1:0]));   // creates 'len' label records:
                                                                            // [1, 0, 0] for left, [0, 1, 0] for no move
                                                                            // and [0, 0, 1] for right (depending on the index of the training data)
    }
    //console.log(data_xs);
    //console.log(data_ys);
    document.getElementById("playing").innerHTML = "Training: "+data_xs.length+" records";

    console.log('training-1');

    // create tensors from the training data
    const xs = tf.tensor(data_xs);
    const ys = tf.tensor(data_ys);

    // "creative" way of running asynchronous code in a synchronous-like manner
    (async function() {
        console.log('training-2');

        // train a model
        let result = await model.fit(xs, ys, {
            batchSize: 32,
            epochs: 1,
            shuffle: true,
            validationSplit: 0.1,
            callbacks: {
                // print batch stats
                onBatchEnd: async (batch, logs) => {
                    console.log("Step "+batch+", loss: "+logs.loss.toFixed(5)+", acc: "+logs.acc.toFixed(5));
                },
            },
        });

        // and save it in a local storage (for later use)
        await model.save('indexeddb://my-model-1');

        // print model and validation stats
        console.log("Model: loss: "+result.history.loss[0].toFixed(5)+", acc: "+result.history.acc[0].toFixed(5));
        console.log("Validation: loss: "+result.history.val_loss[0].toFixed(5)+", acc: "+result.history.val_acc[0].toFixed(5));
    }());

    // note: the async block above is not awaited, so this logs as soon as training has started
    console.log('trained');
}

// inferences a move
AI.prototype.predict_move = function(){

    // but only from the 2nd frame of a game on (we need data from the previous frame as well)
    if(this.previous_xs != null){
        // flip table so ai will see it from player's perspective
        // and try to mimic the player's gameplay
        // also use only ai's paddle positions
        data_xs = [
            width - this.previous_xs[1], width - this.previous_xs[2], height - this.previous_xs[3],
            width - this.previous_xs[5], width - this.previous_xs[6], height - this.previous_xs[7]
        ];
        // predict move
        prediction = model.predict(tf.tensor([data_xs]));
        // argMax returns the class index: 0, 1 or 2; we need -1, 0 or 1 (left, no move, right), so decrement it
        // also we actually need to flip that prediction, as ai plays on top (upside-down)
        return -(tf.argMax(prediction, 1).dataSync()[0]-1);
    }
}

// add canvas
document.body.appendChild(canvas);

// init whole code
init();

// arrow keypress events
window.addEventListener("keydown", function (event) {
    keysDown[event.keyCode] = true;
});
window.addEventListener("keyup", function (event) {
    delete keysDown[event.keyCode];
});
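The trickiest part of the script above is predict_move: the model was trained on the bottom player's view, so the AI's inputs are mirrored before prediction and the predicted move is mirrored back. Here is the same arithmetic as a small Python sketch, for reference only. The function names and the sample numbers are made up; the canvas size and the class order (0 left, 1 no move, 2 right) are taken from the script.

```python
WIDTH, HEIGHT = 400, 600  # canvas size used in pongai.js


def flip_for_ai(frame):
    """Mirror (paddle_x, ball_x, ball_y) so the top paddle sees the player's view."""
    paddle_x, ball_x, ball_y = frame
    return [WIDTH - paddle_x, WIDTH - ball_x, HEIGHT - ball_y]


def class_to_move(class_index):
    """Map class 0/1/2 (left, no move, right) to a paddle direction, mirrored."""
    return -(class_index - 1)


def argmax(values):
    """Index of the largest value, like tf.argMax on a single row."""
    return max(range(len(values)), key=values.__getitem__)


# Hypothetical softmax output favouring class 0 ("left" from the player's view).
probs = [0.7, 0.2, 0.1]
print(flip_for_ai([100, 250, 50]), class_to_move(argmax(probs)))
```

A "left" prediction comes out as +1 here, which ai_update then multiplies by 4 to move the top paddle to the right on screen.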
