<!-- livebook:{"persist_outputs":true} -->

# Writing custom metrics

```elixir
Mix.install([
  {:axon, github: "elixir-nx/axon"},
  {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true}
])
```

<!-- livebook:{"output":true} -->

```
:ok
```

## Writing custom metrics

When passing an atom to `Axon.Loop.metric/5`, Axon dispatches the function to a built-in function in `Axon.Metrics`. If you find you'd like to use a metric that does not exist in `Axon.Metrics`, you can define a custom function:

```elixir
defmodule CustomMetric do
  import Nx.Defn

  defn my_weird_metric(y_true, y_pred) do
    Nx.atan2(y_true, y_pred) |> Nx.sum()
  end
end
```

<!-- livebook:{"output":true} -->

```
{:module, CustomMetric, <<70, 79, 82, 49, 0, 0, 8, ...>>, {:my_weird_metric, 2}}
```

Then you can pass that directly to `Axon.Loop.metric/5`. You must provide a name for your custom metric:

```elixir
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.metric(&CustomMetric.my_weird_metric/2, "my weird metric")
```

<!-- livebook:{"output":true} -->

```
#Axon.Loop<
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<23.77614421/1 in Axon.Loop.log/5>,
       #Function<5.77614421/1 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<23.77614421/1 in Axon.Loop.log/5>,
       #Function<3.77614421/1 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  metrics: %{
    "loss" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>,
     #Function<6.77614421/2 in Axon.Loop.build_loss_fn/1>},
    "my weird metric" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>,
     &CustomMetric.my_weird_metric/2}
  },
  ...
>
```

Then when running, Axon will invoke your custom metric function and accumulate it with the given aggregator:

```elixir
train_data =
  Stream.repeatedly(fn ->
    xs = Nx.random_normal({8, 1})
    ys = Nx.sin(xs)
    {xs, ys}
  end)

Axon.Loop.run(loop, train_data, %{}, iterations: 1000)
```

<!-- livebook:{"output":true} -->

```
Epoch: 0, Batch: 1000, loss: 0.0468431 my weird metric: -5.7462921
```

<!-- livebook:{"output":true} -->

```
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.011475208215415478, 0.23035769164562225, 0.01538881566375494, 0.08167446404695511, 0.23642019927501678, 0.10298296064138412, 0.20279639959335327, -0.18916435539722443]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [0.7426201105117798, 0.734136700630188, -0.5648708343505859, -0.5230435132980347, 0.3056533932685852, 0.3383721709251404, -0.3518844544887543, -0.19460521638393402]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.2185358852148056, 0.23043134808540344, 0.0, 0.2650437355041504]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [0.19164204597473145, -0.26440876722335815, 0.060297321528196335, 0.004777891095727682],
        [0.019263261929154396, -0.6267783045768738, -0.33454063534736633, 0.33268266916275024],
        [-0.18489953875541687, 0.4653063714504242, -0.6056118607521057, -0.046012550592422485],
        [0.5975558161735535, -0.237883061170578, -0.6522921919822693, 0.019332828000187874],
        [-0.7424253225326538, 0.593705952167511, 0.2551117241382599, 0.26270362734794617],
        [0.018434584140777588, 0.15290242433547974, 0.08793036639690399, 0.1839984804391861],
        [0.6048195958137512, -0.20294713973999023, -0.694927990436554, -0.45577046275138855],
        [-0.628790020942688, 0.21741150319576263, -0.08936657756567001, 0.6170362234115601]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.03722470998764038]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [-0.7919473648071289],
        [-0.4341854751110077],
        [-0.39114490151405334],
        [0.9605273008346558]
      ]
    >
  }
}
```

While the metric defaults are designed with supervised training loops in mind, they
can be used for much more flexible purposes. By default, metrics look for the fields `:y_true` and `:y_pred` in the given loop's step state. They then apply the given metric function on those inputs. You can also define metrics which work on other fields. For example you can track the running average of a given parameter with a metric just by defining a custom output transform:

```elixir
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

output_transform = fn %{model_state: model_state} ->
  [model_state["dense_0"]["kernel"]]
end

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.metric(&Nx.mean/1, "dense_0_kernel_mean", :running_average, output_transform)
  |> Axon.Loop.metric(&Nx.variance/1, "dense_0_kernel_var", :running_average, output_transform)
```

<!-- livebook:{"output":true} -->

```
#Axon.Loop<
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<23.77614421/1 in Axon.Loop.log/5>,
       #Function<5.77614421/1 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<23.77614421/1 in Axon.Loop.log/5>,
       #Function<3.77614421/1 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  metrics: %{
    "dense_0_kernel_mean" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>,
     &Nx.mean/1},
    "dense_0_kernel_var" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>,
     &Nx.variance/1},
    "loss" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>,
     #Function<6.77614421/2 in Axon.Loop.build_loss_fn/1>}
  },
  ...
>
```

Axon will apply your custom output transform to the loop's step state and forward the result to your custom metric function:

```elixir
train_data =
  Stream.repeatedly(fn ->
    xs = Nx.random_normal({8, 1})
    ys = Nx.sin(xs)
    {xs, ys}
  end)

Axon.Loop.run(loop, train_data, %{}, iterations: 1000)
```

<!-- livebook:{"output":true} -->

```
Epoch: 0, Batch: 1000, dense_0_kernel_mean: 0.0807205 dense_0_kernel_var: 0.1448047 loss: 0.0626600
```

<!-- livebook:{"output":true} -->

```
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [-0.14429236948490143, 0.3176318109035492, 0.0036036474630236626, 0.01434470433741808, 0.21225003898143768, -0.1406097412109375, 0.32469284534454346, -0.18893203139305115]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [0.2918722331523895, -0.44978663325309753, -0.28219935297966003, -0.10681337863206863, 0.5192054510116577, 0.312747985124588, -0.15127503871917725, 0.5638187527656555]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.0, -0.003864143043756485, 0.5194356441497803, 0.028363214805722237]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.6123268008232117, 0.22753892838954926, 0.12077417969703674, 0.4875330626964569],
        [-0.5840837359428406, 0.2259720116853714, 0.4917944371700287, 0.22638437151908875],
        [-0.22699439525604248, -0.6744257807731628, -0.2907045781612396, 0.35300591588020325],
        [-0.16367988288402557, -0.5971682071685791, -0.39346548914909363, 0.5823913812637329],
        [-0.5512545704841614, -0.6812713742256165, -0.5777145624160767, -0.653957188129425],
        [-0.23620283603668213, -0.47966212034225464, -0.273225873708725, 0.3827615976333618],
        [-0.5591338276863098, -0.1730434000492096, 0.25726518034935, 0.7179149389266968],
        [0.3902169167995453, 0.6351881623268127, -0.602277398109436, 0.40137141942977905]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.824558675289154]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [0.9618374109268188],
        [-0.028266794979572296],
        [-1.1059081554412842],
        [-0.7398673892021179]
      ]
    >
  }
}
```

You can also define custom accumulation
functions. Axon has definitions for computing running averages and running sums; however, you might find you need something like an exponential moving average:

```elixir
defmodule CustomAccumulator do
  import Nx.Defn

  defn running_ema(acc, obs, _i, opts \\ []) do
    opts = keyword!(opts, alpha: 0.9)
    obs * opts[:alpha] + acc * (1 - opts[:alpha])
  end
end
```

<!-- livebook:{"output":true} -->

```
{:module, CustomAccumulator, <<70, 79, 82, 49, 0, 0, 11, ...>>, {:running_ema, 4}}
```

Your accumulator must be an arity-3 function which accepts the current accumulated value, the current observation, and the current iteration, and returns the aggregated metric. You can pass a function directly as an accumulator in your metric:

```elixir
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

output_transform = fn %{model_state: model_state} ->
  [model_state["dense_0"]["kernel"]]
end

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.metric(
    &Nx.mean/1,
    "dense_0_kernel_ema_mean",
    &CustomAccumulator.running_ema/3,
    output_transform
  )
```

<!-- livebook:{"output":true} -->

```
#Axon.Loop<
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<23.77614421/1 in Axon.Loop.log/5>,
       #Function<5.77614421/1 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<23.77614421/1 in Axon.Loop.log/5>,
       #Function<3.77614421/1 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  metrics: %{
    "dense_0_kernel_ema_mean" => {#Function<12.77614421/3 in Axon.Loop.build_metric_fn/3>,
     &Nx.mean/1},
    "loss" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>,
     #Function<6.77614421/2 in Axon.Loop.build_loss_fn/1>}
  },
  ...
>
```

Then when you run the loop, Axon will use your custom accumulator:

```elixir
train_data =
  Stream.repeatedly(fn ->
    xs = Nx.random_normal({8, 1})
    ys = Nx.sin(xs)
    {xs, ys}
  end)

Axon.Loop.run(loop, train_data, %{}, iterations: 1000)
```

<!-- livebook:{"output":true} -->

```
Epoch: 0, Batch: 1000, dense_0_kernel_ema_mean: 0.2137861 loss: 0.0709054
```

<!-- livebook:{"output":true} -->

```
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.08160790055990219, -0.21322371065616608, -0.1431925743818283, 0.2848915755748749, -0.007875560782849789, 0.3923396170139313, -0.04444991424679756, 0.23083189129829407]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.6269387006759644, 0.3289071023464203, 0.19450749456882477, 0.7400281429290771, 0.23878233134746552, 0.36140456795692444, 0.10503113269805908, 0.3685782253742218]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.2350393682718277, 0.06712433695793152, -0.03675961494445801, -0.06366443634033203]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.35826751589775085, -0.10699580609798431, -0.3681609034538269, 0.08517063409090042],
        [-0.7694831490516663, 0.13644370436668396, -0.2390032261610031, 0.6069303154945374],
        [-0.6424086689949036, 0.13374455273151398, -0.35404452681541443, 0.6343701481819153],
        [-0.09528166800737381, 0.7048070430755615, 0.13699916005134583, 0.6482889652252197],
        [-0.08044164627790451, 0.010588583536446095, 0.11140558868646622, 0.33911004662513733],
        [0.7361723780632019, 0.757600724697113, -0.0011848200811073184, 0.2799053192138672],
        [0.3472788631916046, -0.5225644111633301, 0.04859891161322594, -0.4931156039237976],
        [0.09371320903301239, 0.5478940606117249, 0.5831385254859924, -0.21019525825977325]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.835706889629364]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [1.0109968185424805],
        [0.574639618396759],
        [-0.01302765030413866],
        [-0.008134203962981701]
      ]
    >
  }
}
```
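
To get a feel for how this accumulator weights observations, the same update rule can be traced on plain floats, outside of Axon and Nx. The following is only a sketch: the `EmaSketch` module and its `accumulate/2` helper are hypothetical names introduced for illustration, and the starting value of `0.0` is an assumption rather than a statement about how Axon initializes metric state.

```elixir
defmodule EmaSketch do
  # Same update rule as CustomAccumulator.running_ema/3, but on plain floats.
  # `alpha` is the weight given to the newest observation.
  def running_ema(acc, obs, _i, alpha \\ 0.9) do
    obs * alpha + acc * (1 - alpha)
  end

  # Folds a list of observations through the accumulator, passing the
  # iteration index as the third argument.
  # Assumption: the accumulator starts at 0.0; Axon's initial value may differ.
  def accumulate(observations, initial \\ 0.0) do
    observations
    |> Enum.with_index()
    |> Enum.reduce(initial, fn {obs, i}, acc -> running_ema(acc, obs, i) end)
  end
end

# Three observations of 1.0 move the accumulator through roughly
# 0.9, 0.99, and 0.999.
EmaSketch.accumulate([1.0, 1.0, 1.0])
```

With `alpha: 0.9` the newest observation dominates, so the reported value stays close to the most recent metric. A smaller `alpha` gives more weight to history and produces a smoother curve.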