Welcome to part 4 of the TensorFlow.js series, where we're going to train a model in Python and then load that trained model into a TensorFlow.js application. First, we need to train a model in Python.
For that, we need some training data. We could have generated it with Python, but we opted to use JavaScript, running the collector until we had 100,000 training samples. You can run it yourself if you like, but I have also hosted the training data here: Pong AI training data.
If you want to build your own training data (or collect more of it), here are the data-collection files:
ponggame-fixed-strike-noai-collector.js
ponggame-fixed-strike-noai-collector.html
The above scripts simply pit two "computers" against each other, running as fast as possible, to create the training data. Sometimes a rally gets stuck during collection, and you need to change the ball speed manually in the console, or change the code to randomly change the speed from time to time.
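One way to implement the "randomly change the speed" workaround is sketched below. `maybeRerollBallSpeed` is a hypothetical helper, not part of the collector scripts: with a small probability per frame, it re-draws the ball speed from the same ranges the `Ball` constructor in pongai.js uses (`Math.random()*4+1` horizontally, `Math.random()*3+2` vertically) while keeping the current direction of travel. The default probability is an arbitrary assumption; tune it to taste.

```javascript
// Hypothetical helper for the data collector: occasionally re-draw the
// ball speed so two rule-based paddles cannot lock into an endless rally.
// Speed ranges mirror the Ball constructor in pongai.js; the default
// probability is an assumption, not a value from the original scripts.
function maybeRerollBallSpeed(ball, probability = 0.001, rng = Math.random) {
  if (rng() >= probability) {
    return false; // leave the speed untouched this frame
  }
  // Keep the current direction of travel, only change the magnitude.
  const xSign = ball.x_speed < 0 ? -1 : 1;
  const ySign = ball.y_speed < 0 ? -1 : 1;
  ball.x_speed = xSign * (rng() * 4 + 1);
  ball.y_speed = ySign * (rng() * 3 + 2);
  return true;
}
```

You would call it once per frame from the collector's update loop, for example `maybeRerollBallSpeed(ball);` right after the ball position is updated. Passing `rng` in explicitly is only there to make the helper easy to test.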
Okay, so we've got 100,000 training samples in JSON form. What we'd like to do is train a model in Python on this data, and then export it to a format we can use in TensorFlow.js. Our JSON data, saved in training_data-100k.json, looks like: {"xs":[[152,241,124,442,121,244],...], "ys":[[1,0,0],...]}. The first thing we need to do is load this data in Python:
import json
import numpy as np
with open('training_data-100k.json') as f:
    data = json.load(f)
xs = np.array(data['xs'])
ys = np.array(data['ys'])
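Since the data was collected in JavaScript, it can be worth sanity-checking the file's shape before handing it to Python. The sketch below assumes the structure shown above (six numbers per `xs` row, a three-element one-hot row in `ys`, with index 0 = left, 1 = no move, 2 = right as in pongai.js) and that the file has already been parsed with `JSON.parse`; `checkTrainingData` is a hypothetical name.

```javascript
// Sanity-check parsed training data of the form
// {"xs": [[6 numbers], ...], "ys": [[one-hot of length 3], ...]}.
// Returns the sample count and per-class counts; throws on malformed rows.
function checkTrainingData(data, nFeatures = 6, nClasses = 3) {
  const { xs, ys } = data;
  if (!Array.isArray(xs) || !Array.isArray(ys) || xs.length !== ys.length) {
    throw new Error("xs and ys must be arrays of equal length");
  }
  const classCounts = new Array(nClasses).fill(0);
  xs.forEach((row, i) => {
    if (row.length !== nFeatures || !row.every(Number.isFinite)) {
      throw new Error(`xs[${i}] must contain ${nFeatures} finite numbers`);
    }
    const label = ys[i];
    const ones = label.filter((v) => v === 1).length;
    const zeros = label.filter((v) => v === 0).length;
    if (label.length !== nClasses || ones !== 1 || zeros !== nClasses - 1) {
      throw new Error(`ys[${i}] must be one-hot of length ${nClasses}`);
    }
    classCounts[label.indexOf(1)] += 1; // 0 = left, 1 = no move, 2 = right
  });
  return { samples: xs.length, classCounts };
}
```

In Node, for example: `checkTrainingData(JSON.parse(require("fs").readFileSync("training_data-100k.json", "utf8")))`. The returned `classCounts` also shows how balanced the three moves are, which the browser-side `train()` in pongai.js deals with by trimming every class to the same size.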
Then, as we try to find a decent model to use, we need to split this data into training and testing groups:
x_train = xs[:-10000]
y_train = ys[:-10000]
x_test = xs[-10000:]
y_test = ys[-10000:]
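NumPy's negative slices mean "everything except the last 10,000" and "only the last 10,000", so with 100,000 samples this leaves 90,000 for training and 10,000 for testing. The same hold-out split in JavaScript, with `Array.prototype.slice` (which treats negative indices the same way), looks like this; `holdOutSplit` is a hypothetical helper. Nothing is shuffled here, so the test set is simply the most recently collected samples.

```javascript
// Mirror of x_train = xs[:-n], x_test = xs[-n:] (and likewise for ys).
// Array.prototype.slice treats negative indices as offsets from the end.
function holdOutSplit(xs, ys, nTest) {
  return {
    xTrain: xs.slice(0, -nTest), // all but the last nTest rows
    yTrain: ys.slice(0, -nTest),
    xTest: xs.slice(-nTest), // only the last nTest rows
    yTest: ys.slice(-nTest),
  };
}
```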
Now we are ready to build our Keras model.
Start with the following imports:
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout
Then, below our previous code, define the model:
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=6))
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(3, activation='softmax'))
Look familiar? Note that you must specify input_dim for the first layer, but Keras infers the input shape of every layer after it.
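What Keras "figures out" is each layer's input size: a `Dense` layer receives the previous layer's units as its inputs, so only the first one needs `input_dim`. The sketch below (a hypothetical helper, not a Keras API) walks the same 6 → 64 → 64 → 3 stack and counts weights plus biases per `Dense` layer. `Dropout` layers have no parameters and do not change the shape, so they are left out. The total should match the "Total params" line of `model.summary()` for this model.

```javascript
// Infer each Dense layer's input size from the previous layer and count
// its parameters: an (in x units) weight matrix plus one bias per unit.
function denseParamCounts(inputDim, unitsPerLayer) {
  const counts = [];
  let fanIn = inputDim; // only the first layer needs this given explicitly
  for (const units of unitsPerLayer) {
    counts.push(fanIn * units + units);
    fanIn = units; // the next layer's input is this layer's output
  }
  return counts;
}

const perLayer = denseParamCounts(6, [64, 64, 3]);
const total = perLayer.reduce((a, b) => a + b, 0);
console.log(perLayer, total); // [ 448, 4160, 195 ] 4803
```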
Now, just like in TensorFlow.js, we need to compile the model and then fit it to the data.
# older Keras versions name this argument lr
adam = keras.optimizers.Adam(learning_rate=0.001)
model.compile(loss='categorical_crossentropy',
              optimizer=adam,
              metrics=['accuracy'])
model.fit(x_train, y_train,
          epochs=10,
          batch_size=128)
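The final `softmax` layer turns the three outputs into probabilities, and `categorical_crossentropy` compares them with the one-hot `ys` rows: for one sample, the loss is minus the log of the probability assigned to the correct move, and the reported loss is the average over samples. A minimal plain-JavaScript sketch of both ideas follows. Keras also clips predicted probabilities slightly away from 0 and 1 before taking the log; the `eps` clamp imitates that, and its exact value here is an assumption.

```javascript
// Softmax: exponentiate (shifted by the max for numerical stability) and normalize.
function softmax(logits) {
  const max = Math.max(...logits);
  const exps = logits.map((z) => Math.exp(z - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// Categorical cross-entropy for one sample: -sum(y_true * log(y_pred)).
// With a one-hot label only the true class's probability contributes.
function categoricalCrossentropy(yTrue, yPred, eps = 1e-7) {
  return -yTrue.reduce(
    (acc, t, i) => acc + t * Math.log(Math.min(Math.max(yPred[i], eps), 1 - eps)),
    0
  );
}

// Mean loss over a batch of samples.
function meanLoss(batchTrue, batchPred) {
  const sum = batchTrue.reduce((acc, y, i) => acc + categoricalCrossentropy(y, batchPred[i]), 0);
  return sum / batchTrue.length;
}
```

A model that knows nothing and outputs 1/3 for every move scores ln 3 ≈ 1.0986, which is a handy baseline when reading the loss printed during `fit`.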
Finally, we want to save this model when we're done with it, and maybe see the results of the out-of-sample testing:
score = model.evaluate(x_test, y_test, batch_size=128)
print(score)
model.save("Keras-64x2-10epoch")
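`model.evaluate` returns the loss followed by each compiled metric, so `score` here is `[loss, accuracy]`. With `categorical_crossentropy`, Keras treats `'accuracy'` as categorical accuracy: a sample counts as correct when the highest predicted probability sits at the same index as the 1 in its label. A plain-JavaScript sketch of that comparison (helper names are hypothetical):

```javascript
// Index of the largest value; ties resolve to the first index.
function argMax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}

// Fraction of samples whose predicted class matches the one-hot label.
function categoricalAccuracy(yTrue, yPred) {
  let correct = 0;
  yTrue.forEach((label, i) => {
    if (argMax(label) === argMax(yPred[i])) correct += 1;
  });
  return correct / yTrue.length;
}
```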
Now that we have this Keras model, we'd like to convert it to be used within our actual pong application.
To start, we need to install tensorflowjs for Python:
pip install tensorflowjs
Next, we can make the following new import in our training script:
import tensorflowjs as tfjs
Then, instead of (or in addition to) model.save, you can do:
tfjs.converters.save_keras_model(model, "tfjsmodel")
You may still want to save your Keras model too, in case you want to return to it later rather than retrain it.
Full code:
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout
import json
import numpy as np
import tensorflowjs as tfjs
with open('training_data-100k.json') as f:
    data = json.load(f)
xs = np.array(data['xs'])
ys = np.array(data['ys'])
x_train = xs[:-10000]
y_train = ys[:-10000]
x_test = xs[-10000:]
y_test = ys[-10000:]
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=6))
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(3, activation='softmax'))
adam = keras.optimizers.Adam(learning_rate=0.001)
model.compile(loss='categorical_crossentropy',
              optimizer=adam,
              metrics=['accuracy'])
model.fit(x_train, y_train,
          epochs=10,
          batch_size=128)
score = model.evaluate(x_test, y_test, batch_size=128)
print(score)
model.save("Keras-64x2-10epoch")
tfjs.converters.save_keras_model(model, "tfjsmodel")
Now you should have everything you need in the local tfjsmodel directory. Next, we just need to load that model!
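The converter writes a `model.json` file (the architecture plus a `weightsManifest`) next to one or more binary weight files, and `model.json` refers to those files by relative path, so all of them have to be uploaded together, into the same directory, wherever you host the model. A small sketch (the helper name is hypothetical) lists exactly which files a given `model.json` expects:

```javascript
// List the files a TensorFlow.js Layers model.json depends on:
// model.json itself plus every weight file named in its weightsManifest.
function filesToHost(modelJson) {
  const shards = (modelJson.weightsManifest || []).flatMap((group) => group.paths);
  return ["model.json", ...shards];
}
```

In Node, for example: `filesToHost(JSON.parse(require("fs").readFileSync("tfjsmodel/model.json", "utf8")))`.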
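Finally, to load this model in TensorFlow.js, we just need to point tf.loadModel at the model.json file in our JavaScript (tf.loadModel is the name in the 0.x release used here; TensorFlow.js 1.0 and later call it tf.loadLayersModel):
model = await tf.loadModel('https://path/to/model.json');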
For example:
model = await tf.loadModel('https://news.r6siege.cn/static/downloads/machine-learning-data/tfjsversion/model.json');
The above only works if your TensorFlow.js code is also running on news.r6siege.cn. Otherwise, you will get a CORS (Cross-Origin Resource Sharing) error, because loadModel fetches the files with an HTTP/HTTPS request. So, to load a model from elsewhere, you need to either host the files on the same server your page is served from, or have the server hosting them enable CORS. I didn't want to do that with news.r6siege.cn for security reasons, but I did it for HKinsley.com, so you can instead load the model from https://hkinsley.com/static/tfjsmodel/model.json by doing:
model = await tf.loadModel('https://hkinsley.com/static/tfjsmodel/model.json');
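If you host the converted model yourself, "enabling CORS" comes down to the server sending an `Access-Control-Allow-Origin` header with each file. Below is a minimal sketch using only Node's built-in `http`, `fs` and `path` modules, assuming the converter output sits in a local directory such as `tfjsmodel`. The function names, the port, and the `*` origin (which lets a page on any site load the files) are assumptions to adjust for your own setup.

```javascript
const http = require("http");
const fs = require("fs");
const path = require("path");

// Response headers for a model file. "*" lets pages on any origin fetch it;
// pass a specific origin such as "https://example.com" to be stricter.
function corsHeaders(fileName, allowOrigin = "*") {
  const type = fileName.endsWith(".json") ? "application/json" : "application/octet-stream";
  return { "Access-Control-Allow-Origin": allowOrigin, "Content-Type": type };
}

// Serve the files in `dir` (e.g. the tfjsmodel output) with CORS headers.
function startModelServer(dir, port = 8080, allowOrigin = "*") {
  const server = http.createServer((req, res) => {
    // Only the final path segment is used, so requests cannot leave `dir`.
    const name = path.basename(new URL(req.url, "http://localhost").pathname);
    fs.readFile(path.join(dir, name), (err, body) => {
      if (err) {
        res.writeHead(404, { "Access-Control-Allow-Origin": allowOrigin });
        res.end("not found");
        return;
      }
      res.writeHead(200, corsHeaders(name, allowOrigin));
      res.end(body);
    });
  });
  return server.listen(port);
}
```

With `startModelServer("tfjsmodel", 8080)` running, a page served from a different origin could then call `tf.loadModel('http://localhost:8080/model.json')`.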
Loading the model from HKinsley.com should work just fine for you. Here's the full code you'd need to run the pong AI locally, or even to host it yourself if you wanted:
main.html
<h4>TensorFlow.js implementation of a pong-playing AI</h4>
<p>You play as the bottom paddle, use arrow keys to move.</p>
<h6>Rules: Don't read/judge the js.</h6>
<div id='mainContent'></div>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.11.2"> </script>
<script src="pongai.js"></script>
<p id="playing"></p>
pongai.js
// the actual PONG javascript code is from: https://robots.thoughtbot.com/pong-clone-in-javascript
// We have only modified it to include some javascript AI!
// init function
async function init(){
//model = await tf.loadModel('indexeddb://my-model-1');
model = await tf.loadModel('https://hkinsley.com/static/tfjsmodel/model.json');
//model = await tf.loadModel('tfjsversion/model.json');
console.log('model loaded from storage');
computer.ai_plays = true;
if(computer.ai_plays){
document.getElementById("playing").innerHTML = "Playing: AI";
}else{
document.getElementById("playing").innerHTML = "Playing: Computer";
}
// start a game
animate(step);
}
// set game animation speed (game clock)
var animate = window.requestAnimationFrame || window.webkitRequestAnimationFrame || window.mozRequestAnimationFrame || function (callback) {
window.setTimeout(callback, 1000 / 60)
};
// create canvas
var canvas = document.createElement("canvas");
var width = 400;
var height = 600;
canvas.width = width;
canvas.height = height;
var context = canvas.getContext('2d');
// create game "objects"
var player = new Player();
var computer = new Computer();
var ball = new Ball(200, 300);
var ai = new AI();
// pressed keys
var keysDown = {};
// renders board
var render = function () {
context.fillStyle = "#000000";
context.fillRect(0, 0, width, height);
player.render();
computer.render();
ball.render();
};
// updates game state
var update = function () {
// update player position
player.update(ball);
// update "computer" position
if(computer.ai_plays){
// ai-based
var move = ai.predict_move();
computer.ai_update(move);
}else{
// or rule-based (simply follows the ball) if we don't have a model yet
computer.update(ball);
}
// update ball position
ball.update(player.paddle, computer.paddle);
// add training data from current frame to training set
ai.save_data(player.paddle, computer.paddle, ball)
};
// main game loop
var step = function () {
update();
render();
animate(step); // runs that loop again after a "tick"
};
// paddle object
function Paddle(x, y, width, height) {
this.x = x;
this.y = y;
this.width = width;
this.height = height;
this.x_speed = 0;
this.y_speed = 0;
}
// renders paddle on a board
Paddle.prototype.render = function () {
context.fillStyle = "#59a6ff";
context.fillRect(this.x, this.y, this.width, this.height);
};
// moves paddle by x and y pixels (y is always 0 now)
Paddle.prototype.move = function (x, y) {
// update position and speed
this.x += x;
this.y += y;
this.x_speed = x;
this.y_speed = y;
// check if not out of the board
if (this.x < 0) {
this.x = 0;
this.x_speed = 0;
} else if (this.x + this.width > 400) {
this.x = 400 - this.width;
this.x_speed = 0;
}
};
// computer player object
function Computer() {
this.paddle = new Paddle(0, 10, 50, 10);
//this.ai_plays = false; // will be set to true whenever ai model will be ready
}
// renders computer paddle on a board
Computer.prototype.render = function () {
this.paddle.render();
};
// updates computer paddle position - rule-based (simply follows a ball)
Computer.prototype.update = function (ball) {
// calculate difference in pixels between paddle and ball (cap to 5 pixels - max speed of paddle)
var x_pos = ball.x;
var diff = -((this.paddle.x + (this.paddle.width / 2)) - x_pos);
if (diff < 0 && diff < -4) {
diff = -5;
} else if (diff > 0 && diff > 4) {
diff = 5;
}
// move paddle
this.paddle.move(diff, 0);
// check if paddle is not outside of the board
if (this.paddle.x < 0) {
this.paddle.x = 0;
} else if (this.paddle.x + this.paddle.width > 400) {
this.paddle.x = 400 - this.paddle.width;
}
};
// updates computer paddle position - ai-based (ai calls it later in the code)
Computer.prototype.ai_update = function (move = 0) {
this.paddle.move(4 * move, 0);
};
// player object
function Player() {
this.paddle = new Paddle(0, 580, 50, 10);
}
// renders player paddle
Player.prototype.render = function () {
this.paddle.render();
};
// updates player paddle position
//Player.prototype.update = Computer.prototype.update;
Player.prototype.update = function () {
for (var key in keysDown) {
var value = Number(key);
if (value == 37) {
this.paddle.move(-4, 0);
} else if (value == 39) {
this.paddle.move(4, 0);
} else {
this.paddle.move(0, 0);
}
}
};
// ball object
function Ball(x, y) {
this.x = x;
this.y = y;
this.x_speed = Math.random()*4+1;
this.y_speed = Math.random()*3+2;
this.player_strikes = false;
this.ai_strikes = false;
}
// renders ball on a table
Ball.prototype.render = function () {
context.beginPath();
context.arc(this.x, this.y, 5, 0, 2 * Math.PI, false);
context.fillStyle = "#ddff59";
context.fill();
};
// updates ball position
Ball.prototype.update = function (paddle1, paddle2, new_turn) {
// update speed and upper/lower point of a ball on a table
this.x += this.x_speed;
this.y += this.y_speed;
var top_x = this.x - 5;
var top_y = this.y - 5;
var bottom_x = this.x + 5;
var bottom_y = this.y + 5;
// check if ball is not outside of a table
// bounce off the side walls
if (this.x - 5 < 0) {
this.x = 5;
this.x_speed = -this.x_speed;
} else if (this.x + 5 > 400) {
this.x = 395;
this.x_speed = -this.x_speed;
}
// if ball hits upper and lower walls - reset ball (score)
if (this.y < 0 || this.y > 600) {
this.x_speed = Math.random()*4+1;
this.y_speed = Math.random()*3+2;
this.x = 200;
this.y = 300;
ai.new_turn();
}
// move ball on a table, update angle and speed, calculate new position
this.player_strikes = false;
this.ai_strikes = false;
if (top_y > 300) {
if (top_y < (paddle1.y + paddle1.height) && bottom_y > paddle1.y && top_x < (paddle1.x + paddle1.width) && bottom_x > paddle1.x) {
this.y_speed = -3;
this.x_speed += (paddle1.x_speed / 2);
this.y += this.y_speed;
this.player_strikes = true;
console.log('player strikes');
}
} else {
if (top_y < (paddle2.y + paddle2.height) && bottom_y > paddle2.y && top_x < (paddle2.x + paddle2.width) && bottom_x > paddle2.x) {
this.y_speed = 3;
this.x_speed += (paddle2.x_speed / 2);
this.y += this.y_speed;
this.ai_strikes = true;
console.log('ai strikes');
}
}
};
// AI object
function AI(){
this.previous_data = null; // data from previous frame
this.training_data = [[], [], []]; // empty training dataset
this.training_batch_data = [[], [], []]; // empty batch (dataset to be added to training data)
this.previous_xs = null; // input data from previous frame
this.turn = 0; // turn counter
this.grab_data = true; // enables/disables data grabbing
this.flip_table = true; // flips table
this.keep_trainig_records = true; // keep some number of training records instead of discarding them each session
this.training_records_to_keep = 100000; // number of training records to keep
this.first_strike = true; // first strike flag (to omit data)
}
// saves data from current frame of a game
AI.prototype.save_data = function(player, computer, ball){
// return if grabbing is disabled
if(!this.grab_data)
return;
// fresh turn, just fill initial data in
if(this.previous_data == null){
this.previous_data = [player.x, computer.x, ball.x, ball.y];
return;
}
// if ai strikes, start recording data - empty batch
if(ball.ai_strikes){
this.training_batch_data = [[], [], []];
console.log('emptying batch');
}
// create current data object [player_x, computer_x, ball_x, ball_y]
// and embedding index (0 - left, 1 - no move, 2 - right)
data_xs = [player.x, computer.x, ball.x-60, ball.y];
index = (player.x < this.previous_data[0])?0:((player.x == this.previous_data[0])?1:2);
// save data as [...previous data, ...current data]
// result - [old_player_x, old_computer_x, old_ball_x, old_ball_y, player_x, computer_x, ball_x, ball_y]
this.previous_xs = [...this.previous_data, ...data_xs];
// add data to training set depending on index value (depending if that data relates to the move to the left, no move or move to the right)
// only player and ball position
this.training_batch_data[index].push([this.previous_xs[0], this.previous_xs[2], this.previous_xs[3], this.previous_xs[4], this.previous_xs[6], this.previous_xs[7]]);
// set current data as previous data for next frame
this.previous_data = data_xs;
// if player strikes, add batch to training data
if(ball.player_strikes){
if(this.first_strike){
this.first_strike = false;
this.training_batch_data = [[], [], []];
console.log('emptying batch');
}else{
for(i = 0; i < 3; i++)
this.training_data[i].push(...this.training_batch_data[i]);
this.training_batch_data = [[], [], []];
console.log('adding batch');
}
}
}
// runs every turn
AI.prototype.new_turn = function(){
// clean previous data, we are starting fresh
this.first_strike = true;
this.training_batch_data = [[], [], []];
this.previous_data = null;
this.turn++;
console.log('new turn: ' + this.turn);
//computer.ai_plays = !computer.ai_plays;
if(computer.ai_plays){
document.getElementById("playing").innerHTML = "Playing: AI";
}else{
document.getElementById("playing").innerHTML = "Playing: Computer";
}
// after x turns
/*if(this.turn > 9){
// train a model
this.train();
// allow ai to play (as we have a trained model)
//computer.ai_plays = true;
// empty training dataset
this.reset();
}*/
}
// empties training data
AI.prototype.reset = function(){
this.previous_data = null;
if(!this.keep_trainig_records)
this.training_data = [[], [], []];
this.turn = 0;
if(computer.ai_plays){
document.getElementById("playing").innerHTML = "Playing: AI";
}else{
document.getElementById("playing").innerHTML = "Playing: Computer";
}
console.log('reset')
console.log('emptying batch')
}
// trains a model
AI.prototype.train = function(){
// first we have to balance a data
console.log('balancing');
document.getElementById("playing").innerHTML = "Training";
// trim data and find minimum number of training records in data for all 3 embeddings
if(this.keep_trainig_records){
for(i = 0; i < 3; i++){
if(this.training_data[i].length > this.training_records_to_keep)
this.training_data[i] = this.training_data[i].slice(
Math.max(0, this.training_data[i].length - this.training_records_to_keep),
this.training_data[i].length
);
}
}
len = Math.min(this.training_data[0].length, this.training_data[1].length, this.training_data[2].length);
console.log(this.training_data);
if(!len){
console.log('no data to train on');
return;
}
data_xs = [];
data_ys = [];
// now we need to trim data so every embedding will contain exactly the same amount of training records
// then randomize that data
// and create one embedding record for every input data record
// finally add training data records and embedding records to common tables (for training)
// tf.fit() will do final data shuffle for us
for(i = 0; i < 3; i++){
data_xs.push(...this.training_data[i].slice(0, len)
.sort(()=>Math.random()-0.5).sort(()=>Math.random()-0.5)); // trims training data to 'len' length and shuffle it
data_ys.push(...Array(len).fill([i==0?1:0, i==1?1:0, i==2?1:0])); // creates 'len' number records of embedding data
// either [1, 0, 0] for left, [0, 1, 0] for no move
// and [0, 0, 1] for right (depending on the index of the training data)
}
//console.log(data_xs);
//console.log(data_ys);
document.getElementById("playing").innerHTML = "Training: "+data_xs.length+" records";
console.log('training-1');
// create tensors from the training data
const xs = tf.tensor(data_xs);
const ys = tf.tensor(data_ys);
// run the asynchronous training inside an async IIFE; note that train() returns (and logs 'trained') before fitting finishes
(async function() {
console.log('training-2');
// train a model
let result = await model.fit(xs, ys, {
batchSize: 32,
epochs: 1,
shuffle: true,
validationSplit: 0.1,
callbacks: {
// print batch stats
onBatchEnd: async (batch, logs) => {
console.log("Step "+batch+", loss: "+logs.loss.toFixed(5)+", acc: "+logs.acc.toFixed(5));
},
},
});
// and save it in a local storage (for later use)
await model.save('indexeddb://my-model-1');
// print model and validation stats
console.log("Model: loss: "+result.history.loss[0].toFixed(5)+", acc: "+result.history.acc[0].toFixed(5));
console.log("Validation: loss: "+result.history.val_loss[0].toFixed(5)+", acc: "+result.history.val_acc[0].toFixed(5));
}());
console.log('trained');
}
// inferences a move
AI.prototype.predict_move = function(){
// but only from the second frame of a game onward (we need data from the previous frame as well)
if(this.previous_xs != null){
// flip table so ai will see it from player's perspective
// and try to mimic his gameplay
// also use only the ai's paddle positions
data_xs = [
width - this.previous_xs[1], width - this.previous_xs[2], height - this.previous_xs[3],
width - this.previous_xs[5], width - this.previous_xs[6], height - this.previous_xs[7]
];
// predict move; tf.tidy disposes of the tensors created here on every frame
return tf.tidy(() => {
const prediction = model.predict(tf.tensor([data_xs]));
// argMax returns embedding index 0, 1 or 2, but we need -1, 0 or 1 (left, no move, right), so decrement it
// we also need to flip that prediction, as the ai plays on top (upside-down)
return -(tf.argMax(prediction, 1).dataSync()[0] - 1);
});
}
}
// add canvas
document.body.appendChild(canvas);
// init whole code
init();
// arrow keypress events
window.addEventListener("keydown", function (event) {
keysDown[event.keyCode] = true;
});
window.addEventListener("keyup", function (event) {
delete keysDown[event.keyCode];
});
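The core of pongai.js is the mapping between paddle movement and the three classes. While collecting, `save_data` labels a frame 0, 1 or 2 depending on whether the bottom paddle moved left, stayed put, or moved right. At play time, `predict_move` mirrors the board (each x becomes `width - x`, each y becomes `height - y`) so the top paddle is seen from the bottom player's perspective, then turns the argmax back into -1, 0 or +1 and negates it, because left and right are swapped on the mirrored board. The sketch below restates just that logic as pure functions so it can be checked without a canvas or a model; the function names are hypothetical and the constants are copied from the listing.

```javascript
// Board size copied from pongai.js.
const WIDTH = 400;
const HEIGHT = 600;

// Collection side (save_data): 0 = moved left, 1 = no move, 2 = moved right.
function moveLabel(prevPlayerX, playerX) {
  if (playerX < prevPlayerX) return 0;
  return playerX === prevPlayerX ? 1 : 2;
}

// One-hot row as stored in ys: [1,0,0], [0,1,0] or [0,0,1].
function oneHot(label, nClasses = 3) {
  return Array.from({ length: nClasses }, (_, i) => (i === label ? 1 : 0));
}

// Play side (predict_move): mirror the top paddle's view onto the bottom
// player's perspective. Input is [prevPaddleX, prevBallX, prevBallY,
// paddleX, ballX, ballY], as assembled in the listing (ballX values
// already include its -60 offset).
function mirrorFeatures([pX0, bX0, bY0, pX1, bX1, bY1]) {
  return [WIDTH - pX0, WIDTH - bX0, HEIGHT - bY0, WIDTH - pX1, WIDTH - bX1, HEIGHT - bY1];
}

// argMax index 0/1/2 -> move -1/0/+1, negated for the mirrored board.
// Same value as -(classIndex - 1) in predict_move, written as 1 - classIndex.
function classToMove(classIndex) {
  return 1 - classIndex;
}
```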