<!-- livebook:{"persist_outputs":true} -->
# Writing custom metrics
```elixir
Mix.install([
  {:axon, github: "elixir-nx/axon"},
  {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true}
])
```
<!-- livebook:{"output":true} -->
```
:ok
```
## Writing custom metrics
When you pass an atom to `Axon.Loop.metric/5`, Axon dispatches to the built-in function of that name in `Axon.Metrics`. If you need a metric that does not exist in `Axon.Metrics`, you can define a custom function:
```elixir
defmodule CustomMetric do
  import Nx.Defn

  defn my_weird_metric(y_true, y_pred) do
    Nx.atan2(y_true, y_pred) |> Nx.sum()
  end
end
```
<!-- livebook:{"output":true} -->
```
{:module, CustomMetric, <<70, 79, 82, 49, 0, 0, 8, ...>>, {:my_weird_metric, 2}}
```
You can then pass that function directly to `Axon.Loop.metric/5`. You must provide a name for your custom metric:
```elixir
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.metric(&CustomMetric.my_weird_metric/2, "my weird metric")
```
<!-- livebook:{"output":true} -->
```
#Axon.Loop<
handlers: %{
completed: [],
epoch_completed: [
{#Function<23.77614421/1 in Axon.Loop.log/5>,
#Function<5.77614421/1 in Axon.Loop.build_filter_fn/1>}
],
epoch_halted: [],
epoch_started: [],
halted: [],
iteration_completed: [
{#Function<23.77614421/1 in Axon.Loop.log/5>,
#Function<3.77614421/1 in Axon.Loop.build_filter_fn/1>}
],
iteration_started: [],
started: []
},
metrics: %{
"loss" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>,
#Function<6.77614421/2 in Axon.Loop.build_loss_fn/1>},
"my weird metric" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>,
&CustomMetric.my_weird_metric/2}
},
...
>
```
When you run the loop, Axon invokes your custom metric function and accumulates its results with the given aggregator (here the default, a running average):
```elixir
train_data =
  Stream.repeatedly(fn ->
    xs = Nx.random_normal({8, 1})
    ys = Nx.sin(xs)
    {xs, ys}
  end)

Axon.Loop.run(loop, train_data, %{}, iterations: 1000)
```
<!-- livebook:{"output":true} -->
```
Epoch: 0, Batch: 1000, loss: 0.0468431 my weird metric: -5.7462921
```
<!-- livebook:{"output":true} -->
```
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[0.011475208215415478, 0.23035769164562225, 0.01538881566375494, 0.08167446404695511, 0.23642019927501678, 0.10298296064138412, 0.20279639959335327, -0.18916435539722443]
>,
"kernel" => #Nx.Tensor<
f32[1][8]
[
[0.7426201105117798, 0.734136700630188, -0.5648708343505859, -0.5230435132980347, 0.3056533932685852, 0.3383721709251404, -0.3518844544887543, -0.19460521638393402]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[4]
[0.2185358852148056, 0.23043134808540344, 0.0, 0.2650437355041504]
>,
"kernel" => #Nx.Tensor<
f32[8][4]
[
[0.19164204597473145, -0.26440876722335815, 0.060297321528196335, 0.004777891095727682],
[0.019263261929154396, -0.6267783045768738, -0.33454063534736633, 0.33268266916275024],
[-0.18489953875541687, 0.4653063714504242, -0.6056118607521057, -0.046012550592422485],
[0.5975558161735535, -0.237883061170578, -0.6522921919822693, 0.019332828000187874],
[-0.7424253225326538, 0.593705952167511, 0.2551117241382599, 0.26270362734794617],
[0.018434584140777588, 0.15290242433547974, 0.08793036639690399, 0.1839984804391861],
[0.6048195958137512, -0.20294713973999023, -0.694927990436554, -0.45577046275138855],
[-0.628790020942688, 0.21741150319576263, -0.08936657756567001, 0.6170362234115601]
]
>
},
"dense_2" => %{
"bias" => #Nx.Tensor<
f32[1]
[-0.03722470998764038]
>,
"kernel" => #Nx.Tensor<
f32[4][1]
[
[-0.7919473648071289],
[-0.4341854751110077],
[-0.39114490151405334],
[0.9605273008346558]
]
>
}
}
```
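The loop above accumulates each metric with `Axon.Metrics.running_average/1`, as the `metrics` field of the inspected loop shows. As a rough illustration of what a running average does with successive observations, here is a minimal plain-Elixir sketch. The `RunningAverageSketch` module and its functions are hypothetical names used only for this example, and the sketch works on floats rather than tensors; it is not Axon's implementation.

```elixir
defmodule RunningAverageSketch do
  # Hypothetical helper, not part of Axon: folds one observation into the
  # mean of the `i` observations seen so far (`i` is zero-based).
  def step(avg, obs, i), do: (avg * i + obs) / (i + 1)

  # Applies `step/3` to every observation in order, starting from 0.0.
  def run(observations) do
    observations
    |> Enum.with_index()
    |> Enum.reduce(0.0, fn {obs, i}, avg -> step(avg, obs, i) end)
  end
end

# The accumulated value after each step is 2.0, then 3.0, then 4.0.
RunningAverageSketch.run([2.0, 4.0, 6.0])
```

Because every observation carries equal weight, the value logged at the end of an epoch reflects the whole epoch rather than only the most recent batches.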
While the metric defaults are designed with supervised training loops in mind, they can be used much more flexibly. By default, metrics look for the fields `:y_true` and `:y_pred` in the loop's step state and apply the metric function to those inputs. You can also define metrics that work on other fields. For example, you can track the running average of a given parameter by defining a custom output transform:
```elixir
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

output_transform = fn %{model_state: model_state} ->
  [model_state["dense_0"]["kernel"]]
end

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.metric(&Nx.mean/1, "dense_0_kernel_mean", :running_average, output_transform)
  |> Axon.Loop.metric(&Nx.variance/1, "dense_0_kernel_var", :running_average, output_transform)
```
<!-- livebook:{"output":true} -->
```
#Axon.Loop<
handlers: %{
completed: [],
epoch_completed: [
{#Function<23.77614421/1 in Axon.Loop.log/5>,
#Function<5.77614421/1 in Axon.Loop.build_filter_fn/1>}
],
epoch_halted: [],
epoch_started: [],
halted: [],
iteration_completed: [
{#Function<23.77614421/1 in Axon.Loop.log/5>,
#Function<3.77614421/1 in Axon.Loop.build_filter_fn/1>}
],
iteration_started: [],
started: []
},
metrics: %{
"dense_0_kernel_mean" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>,
&Nx.mean/1},
"dense_0_kernel_var" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>,
&Nx.variance/1},
"loss" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>,
#Function<6.77614421/2 in Axon.Loop.build_loss_fn/1>}
},
...
>
```
Axon will apply your custom output transform to the loop's step state and forward the result to your custom metric function:
```elixir
train_data =
  Stream.repeatedly(fn ->
    xs = Nx.random_normal({8, 1})
    ys = Nx.sin(xs)
    {xs, ys}
  end)

Axon.Loop.run(loop, train_data, %{}, iterations: 1000)
```
<!-- livebook:{"output":true} -->
```
Epoch: 0, Batch: 1000, dense_0_kernel_mean: 0.0807205 dense_0_kernel_var: 0.1448047 loss: 0.0626600
```
<!-- livebook:{"output":true} -->
```
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[-0.14429236948490143, 0.3176318109035492, 0.0036036474630236626, 0.01434470433741808, 0.21225003898143768, -0.1406097412109375, 0.32469284534454346, -0.18893203139305115]
>,
"kernel" => #Nx.Tensor<
f32[1][8]
[
[0.2918722331523895, -0.44978663325309753, -0.28219935297966003, -0.10681337863206863, 0.5192054510116577, 0.312747985124588, -0.15127503871917725, 0.5638187527656555]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[4]
[0.0, -0.003864143043756485, 0.5194356441497803, 0.028363214805722237]
>,
"kernel" => #Nx.Tensor<
f32[8][4]
[
[-0.6123268008232117, 0.22753892838954926, 0.12077417969703674, 0.4875330626964569],
[-0.5840837359428406, 0.2259720116853714, 0.4917944371700287, 0.22638437151908875],
[-0.22699439525604248, -0.6744257807731628, -0.2907045781612396, 0.35300591588020325],
[-0.16367988288402557, -0.5971682071685791, -0.39346548914909363, 0.5823913812637329],
[-0.5512545704841614, -0.6812713742256165, -0.5777145624160767, -0.653957188129425],
[-0.23620283603668213, -0.47966212034225464, -0.273225873708725, 0.3827615976333618],
[-0.5591338276863098, -0.1730434000492096, 0.25726518034935, 0.7179149389266968],
[0.3902169167995453, 0.6351881623268127, -0.602277398109436, 0.40137141942977905]
]
>
},
"dense_2" => %{
"bias" => #Nx.Tensor<
f32[1]
[0.824558675289154]
>,
"kernel" => #Nx.Tensor<
f32[4][1]
[
[0.9618374109268188],
[-0.028266794979572296],
[-1.1059081554412842],
[-0.7398673892021179]
]
>
}
}
```
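One way to picture the output transform is as a function that turns the step state into the argument list for the metric function. The following plain-Elixir sketch illustrates that idea under simplifying assumptions: `step_state` is a stand-in map holding floats instead of tensors, and `absolute_error` and `identity` are hypothetical metric functions invented for the example. It shows the general mechanism only and is not Axon's internal code.

```elixir
# A stand-in for a loop's step state; real step states hold tensors.
step_state = %{
  y_true: 3.0,
  y_pred: 2.5,
  model_state: %{"dense_0" => %{"kernel" => 0.4}}
}

# A transform that selects the label/prediction pair (two metric arguments).
pair_transform = fn %{y_true: y_true, y_pred: y_pred} -> [y_true, y_pred] end

# A transform that selects a single parameter (one metric argument).
kernel_transform = fn %{model_state: model_state} ->
  [model_state["dense_0"]["kernel"]]
end

# Hypothetical metric functions of arity 2 and arity 1.
absolute_error = fn y_true, y_pred -> abs(y_true - y_pred) end
identity = fn value -> value end

# The list returned by the transform becomes the metric's argument list.
error = apply(absolute_error, pair_transform.(step_state))
kernel = apply(identity, kernel_transform.(step_state))
```

The length of the list returned by the transform has to match the arity of the metric function, which is why a single-element list pairs with arity-1 functions such as `&Nx.mean/1`.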
You can also define custom accumulation functions. Axon has definitions for computing running averages and running sums; however, you might find you need something like an exponential moving average:
```elixir
defmodule CustomAccumulator do
  import Nx.Defn

  defn running_ema(acc, obs, _i, opts \\ []) do
    opts = keyword!(opts, alpha: 0.9)
    obs * opts[:alpha] + acc * (1 - opts[:alpha])
  end
end
```
<!-- livebook:{"output":true} -->
```
{:module, CustomAccumulator, <<70, 79, 82, 49, 0, 0, 11, ...>>, {:running_ema, 4}}
```
Your accumulator must be an arity-3 function that accepts the current accumulated value, the current observation, and the current iteration, and returns the aggregated metric. Because `opts` has a default value, `running_ema/3` is defined alongside `running_ema/4`. You can pass a function directly as the accumulator for your metric:
```elixir
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

output_transform = fn %{model_state: model_state} ->
  [model_state["dense_0"]["kernel"]]
end

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.metric(
    &Nx.mean/1,
    "dense_0_kernel_ema_mean",
    &CustomAccumulator.running_ema/3,
    output_transform
  )
```
<!-- livebook:{"output":true} -->
```
#Axon.Loop<
handlers: %{
completed: [],
epoch_completed: [
{#Function<23.77614421/1 in Axon.Loop.log/5>,
#Function<5.77614421/1 in Axon.Loop.build_filter_fn/1>}
],
epoch_halted: [],
epoch_started: [],
halted: [],
iteration_completed: [
{#Function<23.77614421/1 in Axon.Loop.log/5>,
#Function<3.77614421/1 in Axon.Loop.build_filter_fn/1>}
],
iteration_started: [],
started: []
},
metrics: %{
"dense_0_kernel_ema_mean" => {#Function<12.77614421/3 in Axon.Loop.build_metric_fn/3>,
&Nx.mean/1},
"loss" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>,
#Function<6.77614421/2 in Axon.Loop.build_loss_fn/1>}
},
...
>
```
Then when you run the loop, Axon will use your custom accumulator:
```elixir
train_data =
  Stream.repeatedly(fn ->
    xs = Nx.random_normal({8, 1})
    ys = Nx.sin(xs)
    {xs, ys}
  end)

Axon.Loop.run(loop, train_data, %{}, iterations: 1000)
```
<!-- livebook:{"output":true} -->
```
Epoch: 0, Batch: 1000, dense_0_kernel_ema_mean: 0.2137861 loss: 0.0709054
```
<!-- livebook:{"output":true} -->
```
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[0.08160790055990219, -0.21322371065616608, -0.1431925743818283, 0.2848915755748749, -0.007875560782849789, 0.3923396170139313, -0.04444991424679756, 0.23083189129829407]
>,
"kernel" => #Nx.Tensor<
f32[1][8]
[
[-0.6269387006759644, 0.3289071023464203, 0.19450749456882477, 0.7400281429290771, 0.23878233134746552, 0.36140456795692444, 0.10503113269805908, 0.3685782253742218]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[4]
[0.2350393682718277, 0.06712433695793152, -0.03675961494445801, -0.06366443634033203]
>,
"kernel" => #Nx.Tensor<
f32[8][4]
[
[-0.35826751589775085, -0.10699580609798431, -0.3681609034538269, 0.08517063409090042],
[-0.7694831490516663, 0.13644370436668396, -0.2390032261610031, 0.6069303154945374],
[-0.6424086689949036, 0.13374455273151398, -0.35404452681541443, 0.6343701481819153],
[-0.09528166800737381, 0.7048070430755615, 0.13699916005134583, 0.6482889652252197],
[-0.08044164627790451, 0.010588583536446095, 0.11140558868646622, 0.33911004662513733],
[0.7361723780632019, 0.757600724697113, -0.0011848200811073184, 0.2799053192138672],
[0.3472788631916046, -0.5225644111633301, 0.04859891161322594, -0.4931156039237976],
[0.09371320903301239, 0.5478940606117249, 0.5831385254859924, -0.21019525825977325]
]
>
},
"dense_2" => %{
"bias" => #Nx.Tensor<
f32[1]
[-0.835706889629364]
>,
"kernel" => #Nx.Tensor<
f32[4][1]
[
[1.0109968185424805],
[0.574639618396759],
[-0.01302765030413866],
[-0.008134203962981701]
]
>
}
}
```
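To get a feel for how an exponential moving average weights observations, the following plain-Elixir sketch folds a short list of floats through the same update rule as `running_ema`. The `EmaSketch` module is a hypothetical name used only here, and the sketch assumes the accumulator starts at `0.0`; how Axon initializes the accumulated value is not shown in this guide.

```elixir
defmodule EmaSketch do
  # Hypothetical plain-Elixir counterpart of the EMA update rule. The
  # iteration index is accepted to mirror the accumulator contract but
  # is not used by the formula.
  def running_ema(acc, obs, _i, alpha \\ 0.9), do: obs * alpha + acc * (1 - alpha)

  # Folds the update over all observations, starting from 0.0 (an assumption).
  def run(observations, alpha \\ 0.9) do
    observations
    |> Enum.with_index()
    |> Enum.reduce(0.0, fn {obs, i}, acc -> running_ema(acc, obs, i, alpha) end)
  end
end

# With alpha = 0.5 the accumulated value moves 0.5, then 0.75, then 0.875.
EmaSketch.run([1.0, 1.0, 1.0], 0.5)
```

With this formula, `alpha` is the weight given to the newest observation, so the default of `0.9` tracks recent values closely and retains little history.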