110 lines
1.9 KiB
JavaScript
110 lines
1.9 KiB
JavaScript
|
|
||
|
|
||
|
// Number of samples per axis for the rendered decision surface; the unit
// square is sampled at resolution x resolution points.
const resolution = 50;
// Size of one grid cell in normalized [0, 1] units.
const res = 1 / resolution;

// Training tensors, the model, and the most recent prediction values.
// Assigned later by setup() and drawgrid().
let xs;
let ys;
let nn;
let result;

// Per-axis sample coordinates and the flattened list of [a, b] grid points
// fed to the network for rendering; populated by setup().
const x = [];
const y = [];
const predictionArray = [];

// XOR truth table: the four input pairs ...
const inX = [
  [0, 0],
  [0, 1],
  [1, 0],
  [1, 1],
];

// ... and their expected outputs, one column per sample.
const inY = [[0], [1], [1], [0]];
|
||
|
|
||
|
/**
 * p5.js setup hook: prepares the canvas, the XOR model, and the sample grid.
 *
 * - Creates a 900x900 canvas with a gray background at 10 frames per second.
 * - Builds xs/ys from the inX/inY truth table.
 * - Builds and compiles a 2 -> 4 -> 1 sigmoid network (SGD, learning rate 1,
 *   mean squared error) into the global `nn`.
 * - Fills x, y and predictionArray with evenly spaced points on [0, 1).
 */
function setup() {
  createCanvas(900, 900);
  background(50);
  frameRate(10);

  // NOTE(review): tf.setBackend() is asynchronous in tf.js and its promise is
  // not awaited here; the tensors below assume the CPU backend is already
  // active — confirm if the backend choice changes.
  tf.setBackend('cpu');

  // Training data.
  xs = tf.tensor2d(inX);
  ys = tf.tensor2d(inY);

  // 2 inputs -> 4 sigmoid hidden units -> 1 sigmoid output.
  nn = tf.sequential();
  nn.add(tf.layers.dense({ units: 4, inputShape: [2], activation: 'sigmoid' }));
  nn.add(tf.layers.dense({ units: 1, activation: 'sigmoid' }));
  nn.compile({
    optimizer: tf.train.sgd(1),
    loss: 'meanSquaredError',
  });

  // Sample points. Each coordinate is computed from an integer index instead
  // of repeatedly adding 1/resolution: accumulated floating-point error would
  // make the values drift from i/resolution and can produce an extra sample
  // when the running sum lands just below 1.
  for (let i = 0; i < resolution; i++) {
    const u = i / resolution;
    x.push(u);
    y.push(u);
    for (let j = 0; j < resolution; j++) {
      predictionArray.push([u, j / resolution]);
    }
  }
}
|
||
|
|
||
|
/**
 * Run one round of training of `nn` on the XOR tensors and log the loss.
 *
 * @param {number} [epochs=100] - Passes over xs/ys in this round.
 * @returns {Promise<number>} Loss recorded for the final epoch of the round.
 */
async function trainNN(epochs = 100) {
  const {
    history: { loss },
  } = await nn.fit(xs, ys, { epochs, shuffle: true });

  // Report where this round ended up; loss[0] would be the first epoch,
  // i.e. the state before most of this round's updates were applied.
  const finalLoss = loss[loss.length - 1];
  print(finalLoss);
  return finalLoss;
}
|
||
|
|
||
|
// True while a trainNN() round is still running. tf.js refuses to start a
// second fit() on a model whose previous fit() has not finished, so draw()
// must not launch overlapping rounds.
let isTraining = false;

/**
 * Per-frame hook (setup() sets frameRate(10)).
 *
 * Repaints the decision surface from the model's current state, then starts
 * a new training round only if the previous one has finished. Training runs
 * asynchronously across frames; failures are logged rather than left as
 * unhandled promise rejections.
 */
function draw() {
  background(50);
  drawgrid();

  if (isTraining) {
    return;
  }
  isTraining = true;
  trainNN()
    .catch((err) => console.error('trainNN failed:', err))
    .finally(() => {
      isTraining = false;
    });
}
|
||
|
|
||
|
/**
 * Render the network's prediction over the sample grid as grayscale cells.
 *
 * Predicts `nn` on every point of predictionArray and stores the raw values
 * in the global `result`. Each value is scaled by 255 for the fill.
 * The grid's row index i maps to the vertical axis (via mH) and the column
 * index j to the horizontal axis (via mW).
 */
function drawgrid() {
  // Only the tensor work lives inside tidy, so the input and output tensors
  // are released; the values are read out synchronously before tidy ends.
  tf.tidy(() => {
    const inputs = tf.tensor2d(predictionArray);
    result = nn.predictOnBatch(inputs).dataSync();
  });

  // Separate cell width and height so the grid also tiles a non-square canvas.
  const cellW = width / resolution;
  const cellH = height / resolution;

  noStroke();
  for (let i = 0; i < resolution; i++) {
    for (let j = 0; j < resolution; j++) {
      fill(result[i * resolution + j] * 255);
      // mH() flips the axis, so subtract one cell height to get the top edge.
      rect(mW(j / resolution), mH(i / resolution) - cellH, cellW, cellH);
    }
  }
}
|
||
|
|
||
|
/**
 * Convert a position on the [0, 1] scale to a canvas y coordinate.
 *
 * The output range is flipped: 0 maps to `height` and 1 maps to 0.
 *
 * @param {number} val - Normalized vertical position.
 * @returns {number} Canvas y coordinate in pixels.
 */
function mH(val) {
  const atZero = height;
  const atOne = 0;
  return map(val, 0, 1, atZero, atOne);
}
|
||
|
|
||
|
/**
 * Convert a position on the [0, 1] scale to a canvas x coordinate.
 *
 * 0 maps to the left edge (0) and 1 maps to `width`.
 *
 * @param {number} val - Normalized horizontal position.
 * @returns {number} Canvas x coordinate in pixels.
 */
function mW(val) {
  const atZero = 0;
  const atOne = width;
  return map(val, 0, 1, atZero, atOne);
}
|