# Browser front end for an MNIST digit classifier: lets the user draw a digit
# on a canvas, then classifies it with either an MLP trained in the browser or
# a convolutional model whose parameters are fetched from the server.
require "js"
require "base64"

include DNN::Models
include DNN::Layers
include DNN::Optimizers
include DNN::Losses
include DNN::Loaders

Window = JS.global
Document = JS.global[:document]

# Decoded parameter blob for the convolutional model; assigned by `boot`.
$trained_mnist_params = nil

# Drops `then` from JS::Object.
# NOTE(review): presumably so that `promise.then { ... }` below is forwarded to
# the JavaScript object instead of resolving to Ruby's Kernel#then - confirm.
JS::Object.undef_method(:then)

# Builds the multi-layer perceptron used for in-browser training.
#
# The network flattens a 28x28x1 input, runs it through three identical
# Dense(64) -> BatchNormalization -> ReLU stages and ends in a Dense(10) layer.
# After training finishes, the registered callback evaluates the model on the
# global test set and prints accuracy and loss to BrowserConsole.
#
# @return [Sequential] the configured model
def create_model
  mlp = Sequential.new
  mlp << InputLayer.new([28, 28, 1])
  mlp << Flatten.new
  3.times do
    mlp << Dense.new(64)
    mlp << BatchNormalization.new
    mlp << ReLU.new
  end
  mlp << Dense.new(10)
  mlp.setup(Adam.new, SoftmaxCrossEntropy.new)
  mlp.add_lambda_callback(:after_train) do
    acc, loss_value = mlp.evaluate($x_test, $y_test)
    BrowserConsole.puts "accuracy: #{acc}"
    BrowserConsole.puts "loss: #{loss_value}"
  end
  mlp
end

# Starts training `model` and drives it one `trainer.update` call at a time,
# re-scheduling itself through the browser's setTimeout until the trainer
# reports that training has ended.
#
# @param model   the model to train
# @param x_train training inputs
# @param y_train training targets
# @param x_test  test inputs handed to the trainer
# @param y_test  test targets handed to the trainer
def start_training(model, x_train, y_train, x_test, y_test)
  trainer = ModelTrainer.new(model)
  trainer.start_train(x_train, y_train, 3, batch_size: 128, test: [x_test, y_test], io: BrowserConsole)

  # `step` refers to itself so that every invocation can queue the next one.
  step = lambda do
    trainer.update
    if trainer.training?
      JS.global.call(:setTimeout, JS.try_convert(step))
    else
      BrowserConsole.puts("End MLP model training")
    end
  end
  JS.global.call(:setTimeout, JS.try_convert(step))
end

# Creates the convolutional model and loads the fetched parameters into it.
#
# @return the ConvNet model with `$trained_mnist_params` loaded
def load_conv2d_model
  conv_model = ConvNet.create([28, 28, 1])
  # NOTE(review): presumably a dummy forward pass so the layers are built
  # before the parameters are loaded - confirm.
  conv_model.predict1(Numo::SFloat.zeros(28, 28, 1))
  MarshalLoader.new(conv_model).load_bin($trained_mnist_params)
  conv_model
end

# Renders the ten class scores as percentages into the result element.
#
# @param classification indexable with 0..9; each entry is multiplied by 100
#   and rounded to two decimals
def update_result(classification)
  # NOTE(review): entries are separated by a bare newline but written through
  # innerHTML; a "<br>" may have been lost from this string - confirm against
  # the rendered page.
  rows = Array.new(10) { |digit| "#{digit}: #{(classification[digit] * 100).round(2)}%\n" }
  $result_area[:innerHTML] = rows.join
end

# Copies the drawing canvas onto a temporary 28x28 canvas and returns its
# pixels as a 28x28x1 float array (mean of the first three channels, scaled
# by 1/255).
def downscaled_drawing
  scratch = Document.createElement("canvas")
  scratch[:width] = 28
  scratch[:height] = 28
  scratch_ctx = scratch.getContext("2d")
  scratch_ctx.drawImage($draw_canvas, 0, 0, scratch[:width], scratch[:height])
  rgba = scratch_ctx.getImageData(0, 0, scratch[:width], scratch[:height])[:data]
  # The JS pixel buffer is stringified and split on commas to get Ruby integers.
  pixels = Numo::UInt8.cast(rgba.to_s.split(",").map(&:to_i)).reshape(28, 28, 4)
  # Keep channels 0..2 only; the fourth channel is discarded.
  pixels = Numo::SFloat.cast(pixels[true, true, 0..2]) / 255.0
  pixels.mean(axis: 2, keepdims: true)
end

# Removes the element with id "trainOrLoad" from the document body.
def remove_train_or_load_panel
  Document[:body].removeChild(Document.getElementById("trainOrLoad"))
end

# Builds the page, looks up the DOM elements it needs and wires up every
# event handler (judge, clear, start training, load model, and mouse drawing).
def main
  $model = create_model

  # NOTE(review): the text written here contains no markup, yet the lookups
  # below expect elements with the ids "draw", "judge", "clear",
  # "startTraining", "loadModel", "result", "logField" and "trainOrLoad".
  # The HTML inside this heredoc looks like it was stripped - restore it from
  # the original file.
  Document.write(<<-EOS)
Go to github


  EOS

  $draw_canvas = Document.getElementById("draw")
  $draw_context = $draw_canvas.getContext("2d")
  # Fill the entire drawing surface using the context's current fill style.
  $draw_context.fillRect(0, 0, $draw_canvas[:width], $draw_canvas[:height])

  $judge_button = Document.getElementById("judge")
  $clear_button = Document.getElementById("clear")
  $start_training_button = Document.getElementById("startTraining")
  $load_model_button = Document.getElementById("loadModel")
  $result_area = Document.getElementById("result")
  $log_field = Document.getElementById("logField")
  BrowserConsole.dom_element = $log_field

  # Classify whatever is currently drawn and show the scores.
  $judge_button.addEventListener("click") do
    update_result($model.predict1(downscaled_drawing))
  end

  # Repaint the canvas black and wipe the previous scores.
  $clear_button.addEventListener("click") do
    $draw_context[:fillStyle] = "black"
    $draw_context.fillRect(0, 0, $draw_canvas[:width], $draw_canvas[:height])
    $result_area[:innerHTML] = ""
  end

  $start_training_button.addEventListener("click") do
    remove_train_or_load_panel
    start_training($model, $x_train, $y_train, $x_test, $y_test)
  end

  $load_model_button.addEventListener("click") do
    remove_train_or_load_panel
    $model = load_conv2d_model
    BrowserConsole.puts("Load conv model")
  end

  # Track the button state on the window so a release outside the canvas is
  # still noticed.
  $mouse_down = false
  Window.addEventListener("mousedown") { |_event| $mouse_down = true }
  Window.addEventListener("mouseup") { |_event| $mouse_down = false }

  # While the button is held, stamp a white 20x20 square centred on the pointer.
  $draw_canvas.addEventListener("mousemove") do |event|
    next unless $mouse_down

    bounds = event[:target].getBoundingClientRect
    left = event[:clientX].to_s.to_i - 10 - bounds[:left].to_s.to_i
    top = event[:clientY].to_s.to_i - 10 - bounds[:top].to_s.to_i
    $draw_context[:fillStyle] = "white"
    $draw_context.fillRect(left, top, 20, 20)
  end
end

# Fetches the dataset and the pre-trained parameters, stores them in globals,
# removes the "nowLoading" element and then hands over to `main`.
def boot
  JS.global.fetch("mnist_data.marshal.txt").then do |dataset_response|
    dataset_response.text.then do |dataset_text|
      # NOTE(review): Marshal.load can instantiate arbitrary objects; this is
      # only safe as long as the fetched file is trusted.
      raw_x_train, raw_y_train, raw_x_test, raw_y_test = Marshal.load(Base64.decode64(dataset_text.to_s))
      train_images = Numo::SFloat.cast(raw_x_train) / 255
      test_images = Numo::SFloat.cast(raw_x_test) / 255
      train_labels = DNN::Utils.to_categorical(raw_y_train, 10, Numo::SFloat)
      test_labels = DNN::Utils.to_categorical(raw_y_test, 10, Numo::SFloat)
      $x_train = train_images
      $y_train = train_labels
      $x_test = test_images
      $y_test = test_labels

      JS.global.fetch("trained_mnist_params.marshal.txt").then do |params_response|
        params_response.text.then do |params_text|
          $trained_mnist_params = Base64.decode64(params_text.to_s)
          Document[:body].removeChild(Document.getElementById("nowLoading"))
          main
        end
      end
    end
  end
end

boot