JS_Projects/linearRegressionTF/sketch.js
Dawid Pietrykowski 25a12f221e Initial commit
2022-11-09 23:34:49 +01:00

122 lines
2.8 KiB
JavaScript
Executable File

// Training samples in normalized [0, 1] coordinates; mouseClicked() appends to
// both arrays in lockstep, so xs[i] and ys[i] describe one clicked point.
const xs = [];
const ys = [];

// Only referenced by commented-out experiments in the visible file; kept so
// that anything else relying on these names keeps working.
let vars = [];
let deg = 7;

// State assigned in setup(). Declared here so the assignments there target
// real bindings instead of creating implicit globals.
let res; // slider controlling how many segments drawLine() samples
let a; // coefficient of x^6
let b; // coefficient of x^5
let c; // coefficient of x^4
let d; // coefficient of x^3
let e; // coefficient of x^2
let f; // coefficient of x
let g; // constant term

// Adam optimizer with a learning rate of 0.4, stepped once per frame in draw().
const optimizer = tf.train.adam(0.4);
/**
 * p5.js entry point: creates the canvas and the resolution slider, selects the
 * TensorFlow.js backend, and creates the seven trainable polynomial
 * coefficients `a`..`g` used by predict().
 */
function setup() {
  createCanvas(1920, 850);
  background(51);
  frameRate(144);

  // Select the backend before any tensor is allocated, so the coefficient
  // variables below are not created on a previously active backend first.
  tf.setBackend('cpu');

  // Slider from 0 to 100 (initially 10); draw() passes its value to drawLine().
  res = createSlider(0, 100, 10);

  // Each coefficient starts at a random value drawn with random(-1, 1).
  const randomCoefficient = () => tf.variable(tf.scalar(random(-1, 1)));
  a = randomCoefficient();
  b = randomCoefficient();
  c = randomCoefficient();
  d = randomCoefficient();
  e = randomCoefficient();
  f = randomCoefficient();
  g = randomCoefficient();
}
/**
 * p5.js frame callback: clears the canvas, runs one optimizer step on the
 * collected samples (if any), then draws the fitted curve and the samples.
 */
function draw() {
  clear();
  background(51);

  // Training and the curve are skipped until at least one point was clicked.
  if (xs.length >= 1) {
    const ystf = tf.tensor1d(ys);
    try {
      // One optimizer step on loss(predict(xs), labels).
      optimizer.minimize(() => loss(predict(xs), ystf));
    } finally {
      // Release the label tensor even when the training step throws.
      tf.dispose(ystf);
    }
    drawLine(res.value());
  }

  drawPoins();
}
/**
 * Evaluates the degree-6 polynomial
 *   y = a*x^6 + b*x^5 + c*x^4 + d*x^3 + e*x^2 + f*x + g
 * for every input value, using the global coefficient variables `a`..`g`.
 *
 * The input tensor and the intermediates are not disposed here; callers are
 * expected to invoke this inside tf.tidy() or an optimizer step.
 *
 * @param {number[]} xst - x values to evaluate.
 * @returns {tf.Tensor} Tensor with one prediction per entry of `xst`.
 */
function predict(xst) {
  // Local binding: the original assigned to undeclared `x` and `result`,
  // which leaks globals and throws in strict mode.
  const x = tf.tensor1d(xst);
  return x.pow(tf.scalar(6)).mul(a)
    .add(x.pow(tf.scalar(5)).mul(b))
    .add(x.pow(tf.scalar(4)).mul(c))
    .add(x.pow(tf.scalar(3)).mul(d))
    .add(x.square().mul(e))
    .add(x.mul(f))
    .add(g);
}
/**
 * Mean squared error: subtracts `labels` from `pred`, squares the difference
 * and reduces it with mean().
 *
 * @param {tf.Tensor} pred - predicted values.
 * @param {tf.Tensor} labels - target values.
 * @returns {tf.Tensor} Result of the mean() reduction.
 */
function loss(pred, labels) {
  const residual = pred.sub(labels);
  const squaredResidual = residual.square();
  return squaredResidual.mean();
}
/**
 * Draws the current polynomial over x in [0, 1] as a white p5 curve.
 *
 * NOTE(review): p5's curveVertex() treats the first and last vertex of a shape
 * as control points, so the outermost segments may not be rendered — confirm
 * whether that is intended.
 *
 * @param {number} resolution - Number of equal segments to sample along x
 *   (resolution + 1 vertices). Values below 1 draw nothing.
 */
function drawLine(resolution) {
  // The slider can report 0; 1 / 0 would turn every sample into NaN.
  if (!(resolution >= 1)) {
    return;
  }

  const step = 1 / resolution;
  const sampleXs = [];
  for (let i = 0; i <= resolution; i++) {
    sampleXs.push(step * i);
  }

  tf.tidy(() => {
    // Evaluate every sample in one batched call and read the values back
    // once; tidy() releases the tensors predict() allocates.
    const sampleYs = predict(sampleXs).dataSync();

    push();
    stroke(255);
    strokeWeight(4);
    noFill();
    beginShape();
    for (let i = 0; i < sampleXs.length; i++) {
      curveVertex(mW(sampleXs[i]), mH(sampleYs[i]));
    }
    endShape();
    pop();
  });
}
/**
 * p5.js mouse handler: records a click that lies inside the canvas as a new
 * training sample, with both coordinates normalized to [0, 1].
 */
function mouseClicked() {
  // Check all four edges: clicks left of or above the canvas yield negative
  // coordinates and would otherwise be stored as samples outside [0, 1].
  const isInsideCanvas =
    mouseX >= 0 && mouseX <= width && mouseY >= 0 && mouseY <= height;
  if (!isInsideCanvas) {
    return;
  }

  xs.push(map(mouseX, 0, width, 0, 1));
  // Flip the axis: the top of the canvas maps to 1, the bottom to 0.
  ys.push(map(mouseY, 0, height, 1, 0));
}
/**
 * Converts a normalized y value to a canvas y coordinate: 0 maps to `height`
 * (bottom edge) and 1 maps to 0 (top edge).
 *
 * @param {number} val - normalized value.
 * @returns {number} Result of p5's map() for that range.
 */
function mH(val) {
  const bottomEdge = height;
  const topEdge = 0;
  return map(val, 0, 1, bottomEdge, topEdge);
}
/**
 * Converts a normalized x value to a canvas x coordinate: 0 maps to 0 (left
 * edge) and 1 maps to `width` (right edge).
 *
 * @param {number} val - normalized value.
 * @returns {number} Result of p5's map() for that range.
 */
function mW(val) {
  const leftEdge = 0;
  const rightEdge = width;
  return map(val, 0, 1, leftEdge, rightEdge);
}
/**
 * Draws every collected sample as a point, mapping the normalized
 * coordinates in `xs`/`ys` back to canvas pixels (y axis flipped).
 * The misspelled name is kept because draw() calls it.
 */
function drawPoins() {
  stroke(230);
  strokeWeight(5);
  xs.forEach((sampleX, index) => {
    const px = map(sampleX, 0, 1, 0, width);
    const py = map(ys[index], 0, 1, height, 0);
    point(px, py);
  });
}