// Sketch state. These stay global, with their original var/let/const kinds,
// because p5.js global mode and other scripts (e.g. the one that defines
// dataToTrain) may read them.
var history; // unused in this file
let resolution = 50; // unused in this file
var xs; // training input tensor, built once in setup()
var ys; // training target tensor, built once in setup()
var nn; // the tf.sequential classifier
var result; // unused in this file
var inX = []; // per keyPressed(): one 0/1 label per recorded sample
var inY = []; // per keyPressed(): one [r, g, b] triple per recorded sample
let x = []; // unused in this file
let y = []; // unused in this file
var xjson; // last JSON export of inX (see saveData)
var yjson; // last JSON export of inY (see saveData)
var cl; // the colour currently shown / being labelled
var done = 0; // number of samples labelled this session
const txt = 'Vestibulum porttitor convallis sem ut dictum. Sed auctor, libero ut venenatis tristique, urna tellus ornare nisi, sed imperdiet justo turpis laoreet libero. Nullam blandit dignissim elit, et pretium sem feugiat nec. Aliquam enim lectus, feugiat in mi in, semper suscipit odio. Vestibulum ante ipsum primis in faucibus orci luctus et.';
var dataTT; // instance of dataToTrain (defined outside this file)
var training = false; // not read by any active code in this file
var stop = false; // set to true to end the trainNN() loop
let weights = []; // filled by getWeights(): kernel arrays
let biases = []; // filled by getWeights(): bias arrays

/** Returns a p5 colour whose R, G and B channels are each random(0, 255). */
function randomColor() {
  return color(random(0, 255), random(0, 255), random(0, 255));
}

/**
 * p5.js entry point: creates the canvas, builds and compiles a 3 -> 6 -> 6 -> 1
 * sigmoid network, loads the stored samples into tensors and starts the
 * background training loop.
 */
function setup() {
  createCanvas(1920, 933);
  background(50);
  frameRate(1);

  dataTT = new dataToTrain();
  // NOTE(review): tf.setBackend returns a Promise; the code relies on the CPU
  // backend being usable immediately — awaiting tf.ready() would be safer.
  tf.setBackend('cpu');

  nn = tf.sequential();
  const optimizer = tf.train.adamax();
  const hidden1 = tf.layers.dense({ units: 6, inputShape: [3], activation: 'sigmoid', name: 'hidden1' });
  const hidden2 = tf.layers.dense({ units: 6, activation: 'sigmoid', name: 'hidden2' });
  const output = tf.layers.dense({ units: 1, activation: 'sigmoid', name: 'output' });
  nn.add(hidden1);
  nn.add(hidden2);
  nn.add(output);
  nn.compile({ optimizer: optimizer, loss: 'meanSquaredError' });

  cl = randomColor();
  textSize(36);
  textAlign(CENTER, CENTER);

  retData();
  // NOTE(review): the model takes one RGB triple per row (inputShape [3]) and
  // predicts one value, so this assumes dataTT.xs is N x 3 colours and
  // dataTT.ys is N labels. keyPressed()/saveData() store the opposite way
  // round (inX = labels, inY = colours) — confirm which orientation
  // dataToTrain provides.
  xs = tf.tensor2d(inX);
  ys = tf.tensor1d(inY);

  // trainNN handles its own errors, so the returned promise never rejects.
  trainNN();
}

/**
 * p5.js per-frame hook. Defined so p5 keeps its draw loop running; training
 * is driven by trainNN's own timer, not by this function.
 */
function draw() {}

/**
 * Runs one 10-epoch fit() on xs/ys, shows the most recent epoch's loss and
 * reschedules itself until `stop` becomes true. A training error is logged
 * and ends the loop, instead of surfacing as an unhandled promise rejection.
 */
async function trainNN() {
  if (stop) {
    return;
  }
  try {
    const o = await nn.fit(xs, ys, { epochs: 10, shuffle: true });
    const losses = o.history.loss;
    // history.loss has one entry per epoch; the last one is the current loss.
    const latestLoss = losses[losses.length - 1];
    clear();
    text(floor(latestLoss * 10000) / 10000, width / 5, width / 10);
  } catch (err) {
    console.error('trainNN: training failed, stopping the training loop.', err);
    return;
  }
  setTimeout(trainNN, 0);
}

/**
 * p5.js key hook.
 * LEFT labels the current colour 0, RIGHT labels it 1; both record the sample
 * and draw a new random colour. UP exports the recorded samples via saveData().
 * Samples go into inX/inY only; the xs/ys tensors are not rebuilt here.
 */
function keyPressed() {
  if (keyCode === LEFT_ARROW || keyCode === RIGHT_ARROW) {
    inX.push(keyCode === LEFT_ARROW ? 0 : 1);
    // levels is [r, g, b, a]; only the RGB channels are recorded.
    inY.push(cl.levels.slice(0, 3));
    cl = randomColor();
    done++;
  } else if (keyCode === UP_ARROW) {
    saveData();
  }
}

/**
 * Copies the model's parameters into the global `weights` (even indices of
 * nn.getWeights()) and `biases` (odd indices) as nested JS arrays.
 * Assumes every layer contributes a kernel followed by a bias, as the dense
 * layers built in setup() do.
 */
function getWeights() {
  weights = [];
  biases = [];
  const tensors = nn.getWeights();
  tensors.forEach((tensor, i) => {
    (i % 2 === 0 ? weights : biases).push(tensor.arraySync());
  });
}

/** Downloads inX as xs.txt and inY as ys.txt, both JSON-encoded. */
function saveData() {
  xjson = JSON.stringify(inX);
  yjson = JSON.stringify(inY);
  download(xjson, 'xs.txt', 'text/plain');
  download(yjson, 'ys.txt', 'text/plain');
}

/**
 * Replaces inX/inY with the arrays held by dataTT.
 * The arrays are shared, not copied, so later keyPressed() pushes also
 * modify dataTT.xs / dataTT.ys.
 */
function retData() {
  inX = dataTT.xs;
  inY = dataTT.ys;
}

/**
 * Triggers a browser download of `content` as `fileName` with MIME type
 * `contentType`, then releases the temporary object URL.
 */
function download(content, fileName, contentType) {
  const a = document.createElement("a");
  const file = new Blob([content], { type: contentType });
  const url = URL.createObjectURL(file);
  a.href = url;
  a.download = fileName;
  a.click();
  // Revoke after a delay: revoking synchronously may cancel the download in
  // some browsers, and never revoking keeps the Blob alive for the page's life.
  setTimeout(() => URL.revokeObjectURL(url), 1000);
}

/**
 * Fills the canvas with the current colour and draws the sample text twice:
 * black on the left, white on the right. Also shows the labelled-sample count.
 */
function drawText() {
  background(cl);
  fill(0);
  rectMode(CENTER);
  text(txt, width / 8 * 2, height / 2, (width / 8) * 3, height / 3);
  fill(255);
  text(txt, width / 8 * 6, height / 2, (width / 8) * 3, height / 3);
  text(done, width / 2, height / 1.2, 200, 300);
}

/**
 * Picks a random background colour and asks the model for a score.
 * Draws the text black if the score is < 0.5, otherwise white.
 * Repeats every 2 seconds.
 */
function drawColor() {
  cl = randomColor();
  // tf.tidy disposes the input and prediction tensors; without it every
  // call leaked both, indefinitely, because this function reschedules itself.
  const outcome = tf.tidy(() =>
    nn.predict(tf.tensor2d([cl.levels.slice(0, 3)])).bufferSync().get(0));
  rectMode(CENTER);
  fill(outcome < 0.5 ? 0 : 255);
  background(cl);
  text(txt, width / 2, height / 2, (width / 8) * 5, height / 3);
  setTimeout(drawColor, 2000);
}