JS_Projects/carDrivingNeuralNet/network.js

33 lines
840 B
JavaScript
Raw Normal View History

2022-11-09 23:34:49 +01:00
/**
 * Small feed-forward network built with TensorFlow.js (the global `tf`).
 *
 * Architecture as configured in the constructor:
 *   9 inputs -> dense(6) -> dense(6) -> dense(2)
 * No `activation` is passed to any layer, so each uses the tf.layers.dense
 * default (linear); stacked linear layers compose to a single linear map.
 * NOTE(review): if non-linear behavior is intended, add activations.
 *
 * The meaning of the 9 inputs and 2 outputs is not visible here — presumably
 * car sensor readings and control values; TODO confirm against callers.
 */
class NeuralNetwork {
  constructor() {
    this.nn = tf.sequential();
    this.hidden1 = tf.layers.dense({
      units: 6,
      inputShape: [9],
    });
    this.hidden2 = tf.layers.dense({
      units: 6,
    });
    this.output = tf.layers.dense({
      units: 2,
    });
    this.nn.add(this.hidden1);
    this.nn.add(this.hidden2);
    this.nn.add(this.output);
    this.nn.compile({
      // NOTE(review): a learning rate of 1 is unusually high for Adam; it only
      // matters if the model is trained somewhere outside this class.
      optimizer: tf.train.adam(1),
      loss: 'meanSquaredError',
    });
    // Value held until the first call to prediction(); the meaning of [1, 0]
    // is not evident from this file — TODO confirm.
    this.predictions = [1, 0];
  }

  /**
   * Runs one forward pass on a single sample.
   * @param {number[]} inputs - One sample of 9 values (wrapped into a batch of 1).
   * @returns {number[]} The 2 output values for that sample, as a plain array.
   */
  drive(inputs) {
    // tf.tidy disposes every tensor created inside the callback (the input
    // tensor and the model output) once the plain array has been extracted,
    // so repeated calls do not leak backend memory.
    return tf.tidy(() => this.nn.predict(tf.tensor([inputs])).arraySync()[0]);
  }

  /**
   * Runs a forward pass on `data` and stores the result in `this.predictions`.
   * @param {number[]} data - One sample of 9 values.
   * @returns {void}
   */
  prediction(data) {
    this.predictions = this.drive(data);
  }
}