<!-- livebook:{"app_settings":{"slug":"osergioneto"},"persist_outputs":true} -->
# Fine-tuning
```elixir
Mix.install([
  {:bumblebee, "~> 0.3.0"},
  {:axon, "~> 0.5.1"},
  {:nx, "~> 0.5.1"},
  {:exla, "~> 0.5.1"},
  {:explorer, "~> 0.5.0"}
])
```
<!-- livebook:{"output":true} -->
```
Resolving Hex dependencies...
Resolution completed in 0.174s
New:
axon 0.5.1
bumblebee 0.3.0
castore 1.0.2
complex 0.5.0
decimal 2.1.1
elixir_make 0.7.6
exla 0.5.3
explorer 0.5.7
jason 1.4.0
nx 0.5.3
nx_image 0.1.1
nx_signal 0.1.0
progress_bar 2.0.1
rustler_precompiled 0.6.1
table 0.1.2
table_rex 3.1.1
telemetry 1.2.1
tokenizers 0.3.2
unpickler 0.1.0
unzip 0.8.0
xla 0.4.4
* Getting bumblebee (Hex package)
* Getting axon (Hex package)
* Getting nx (Hex package)
* Getting exla (Hex package)
* Getting explorer (Hex package)
* Getting rustler_precompiled (Hex package)
* Getting table (Hex package)
* Getting table_rex (Hex package)
* Getting castore (Hex package)
* Getting elixir_make (Hex package)
* Getting telemetry (Hex package)
* Getting xla (Hex package)
* Getting complex (Hex package)
* Getting jason (Hex package)
* Getting nx_image (Hex package)
* Getting nx_signal (Hex package)
* Getting progress_bar (Hex package)
* Getting tokenizers (Hex package)
* Getting unpickler (Hex package)
* Getting unzip (Hex package)
* Getting decimal (Hex package)
==> unzip
Compiling 5 files (.ex)
Generated unzip app
==> decimal
Compiling 4 files (.ex)
Generated decimal app
==> progress_bar
Compiling 10 files (.ex)
Generated progress_bar app
==> table
Compiling 5 files (.ex)
Generated table app
===> Analyzing applications...
===> Compiling telemetry
==> jason
Compiling 10 files (.ex)
Generated jason app
==> unpickler
Compiling 3 files (.ex)
Generated unpickler app
==> complex
Compiling 2 files (.ex)
Generated complex app
==> nx
Compiling 31 files (.ex)
Generated nx app
==> nx_image
Compiling 1 file (.ex)
Generated nx_image app
==> nx_signal
Compiling 3 files (.ex)
Generated nx_signal app
==> table_rex
Compiling 7 files (.ex)
Generated table_rex app
==> axon
Compiling 23 files (.ex)
Generated axon app
==> castore
Compiling 1 file (.ex)
Generated castore app
==> rustler_precompiled
Compiling 4 files (.ex)
Generated rustler_precompiled app
==> explorer
Compiling 19 files (.ex)
19:52:00.047 [debug] Copying NIF from cache and extracting to /Users/sn/Library/Caches/mix/installs/elixir-1.14.4-erts-13.2.1/38bf24ebae56c4829a860bb9639bd3f8/_build/dev/lib/explorer/priv/native/libexplorer-v0.5.7-nif-2.16-aarch64-apple-darwin.so
Generated explorer app
==> tokenizers
Compiling 7 files (.ex)
19:52:01.034 [debug] Copying NIF from cache and extracting to /Users/sn/Library/Caches/mix/installs/elixir-1.14.4-erts-13.2.1/38bf24ebae56c4829a860bb9639bd3f8/_build/dev/lib/tokenizers/priv/native/libex_tokenizers-v0.3.2-nif-2.16-aarch64-apple-darwin.so
Generated tokenizers app
==> bumblebee
Compiling 89 files (.ex)
Generated bumblebee app
==> elixir_make
Compiling 6 files (.ex)
Generated elixir_make app
==> xla
Compiling 2 files (.ex)
Generated xla app
==> exla
Unpacking /Users/sn/Library/Caches/xla/0.4.4/cache/download/xla_extension-aarch64-darwin-cpu.tar.gz into /Users/sn/Library/Caches/mix/installs/elixir-1.14.4-erts-13.2.1/38bf24ebae56c4829a860bb9639bd3f8/deps/exla/cache
Using libexla.so from /Users/sn/Library/Caches/xla/exla/elixir-1.14.4-erts-13.2.1-xla-0.4.4-exla-0.5.3-yywvdsvg3ykdeecnbffiqglm6i/libexla.so
Compiling 21 files (.ex)
Generated exla app
```
<!-- livebook:{"output":true} -->
```
:ok
```
## Introduction
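We first set EXLA as the default Nx backend, so tensor operations run through the compiled XLA backend rather than the pure-Elixir `Nx.BinaryBackend` (the cell returns the previous default):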
```elixir
Nx.default_backend(EXLA.Backend)
```
<!-- livebook:{"output":true} -->
```
{Nx.BinaryBackend, []}
```
Fine-tuning is the process of specializing the parameters of a pre-trained model to a specific task. Large language models such as BERT are trained on a generic language-modeling objective, which makes them powerful feature extractors for text. Despite that power, you often still need to train them on a downstream task.

This example demonstrates how to use Bumblebee and Axon to fine-tune a pre-trained BERT model to classify Yelp reviews into classes of 1 to 5 stars. It is based on [Fine-tune a pretrained model](https://huggingface.co/docs/transformers/training) from Hugging Face.

You'll first need to download the Yelp Reviews dataset ([download](https://s3.amazonaws.com/fast-ai-nlp/yelp_review_full_csv.tgz)).

Once downloaded, extract it to a directory of your choosing and you're ready to go!
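If you'd rather script that step from the notebook, here's a minimal sketch, assuming `curl` and `tar` are available on your `PATH` (any other download method works just as well):

```elixir
# Hypothetical download cell: fetch and extract the Yelp Reviews archive.
dataset_dir = Path.expand("~/Downloads")
archive = Path.join(dataset_dir, "yelp_review_full_csv.tgz")

unless File.exists?(archive) do
  {_, 0} =
    System.cmd("curl", [
      "-L",
      "-o",
      archive,
      "https://s3.amazonaws.com/fast-ai-nlp/yelp_review_full_csv.tgz"
    ])
end

# Extracts to <dataset_dir>/yelp_review_full_csv/{train,test}.csv
{_, 0} = System.cmd("tar", ["-xzf", archive, "-C", dataset_dir])
```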
## Load a model
We'll start by loading a pre-trained model and tokenizer; however, we'll initialize the model to have an untrained sequence classification head.
Reviews in the dataset can have anywhere from 1 to 5 stars, which means we need 5 labels in our sequence classification head. We can change the default configuration by loading the model spec with `Bumblebee.load_spec/2` and making changes to spec properties with `Bumblebee.configure/2`.
The pre-trained model we'll be using is `bert-base-cased`; however, you can use any of the supported models from the Hugging Face Hub.
```elixir
{:ok, spec} =
Bumblebee.load_spec({:hf, "bert-base-cased"},
architecture: :for_sequence_classification
)
spec = Bumblebee.configure(spec, num_labels: 5)
{:ok, model} = Bumblebee.load_model({:hf, "bert-base-cased"}, spec: spec)
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-cased"})
```
<!-- livebook:{"output":true} -->
```
|=============================================================| 100% (435.77 MB)
23:46:09.618 [info] TfrtCpuClient created.
23:46:10.748 [debug] the following parameters were missing:
* sequence_classification_head.output.kernel
* sequence_classification_head.output.bias
23:46:10.748 [debug] the following PyTorch parameters were unused:
* cls.predictions.bias
* cls.predictions.decoder.weight
* cls.predictions.transform.LayerNorm.beta
* cls.predictions.transform.LayerNorm.gamma
* cls.predictions.transform.dense.bias
* cls.predictions.transform.dense.weight
* cls.seq_relationship.bias
* cls.seq_relationship.weight
```
<!-- livebook:{"output":true} -->
```
{:ok,
%Bumblebee.Text.BertTokenizer{
tokenizer: #Tokenizers.Tokenizer<[
vocab_size: 28996,
continuing_subword_prefix: "##",
max_input_chars_per_word: 100,
model_type: "bpe",
unk_token: "[UNK]"
]>,
special_tokens: %{cls: "[CLS]", mask: "[MASK]", pad: "[PAD]", sep: "[SEP]", unk: "[UNK]"}
}}
```
## Prepare a dataset
With the model and tokenizer downloaded and ready to go, you need to prepare the dataset. The downloaded dataset is a CSV, so you can use the `Explorer` library to quickly load it into a DataFrame.

Once the data is loaded, you need to convert the raw text to tokens and the raw labels to tensors. Additionally, you need to convert the DataFrame to a stream of `{tokenized, labels}` tuples, which is the form expected by Axon's training API.
```elixir
defmodule Yelp do
  # Loads the CSV at `path` and returns a stream of {tokenized, labels} batches.
  def load(path, tokenizer, opts \\ []) do
    path
    |> Explorer.DataFrame.from_csv!(header: false)
    |> Explorer.DataFrame.rename(["label", "text"])
    |> stream()
    |> tokenize_and_batch(tokenizer, opts[:batch_size], opts[:sequence_length])
  end

  # Converts the DataFrame into a lazy stream of {text, label} pairs.
  def stream(df) do
    xs = df["text"]
    ys = df["label"]

    xs
    |> Explorer.Series.to_enum()
    |> Stream.zip(Explorer.Series.to_enum(ys))
  end

  # Chunks the stream into batches, tokenizing the text and stacking the
  # labels into a tensor for each batch.
  def tokenize_and_batch(stream, tokenizer, batch_size, sequence_length) do
    stream
    |> Stream.chunk_every(batch_size)
    |> Stream.map(fn batch ->
      {text, labels} = Enum.unzip(batch)
      tokenized = Bumblebee.apply_tokenizer(tokenizer, text, length: sequence_length)
      {tokenized, Nx.stack(labels)}
    end)
  end
end
```
<!-- livebook:{"output":true} -->
```
{:module, Yelp, <<70, 79, 82, 49, 0, 0, 13, ...>>, {:tokenize_and_batch, 4}}
```
Now you can use the `Yelp.load/3` function to load a training set and a testing set:
```elixir
batch_size = 32
sequence_length = 64
train_data =
Yelp.load("/Users/sn/Downloads/yelp_review_full_csv/train.csv", tokenizer,
batch_size: batch_size,
sequence_length: sequence_length
)
test_data =
Yelp.load("/Users/sn/Downloads/yelp_review_full_csv/test.csv", tokenizer,
batch_size: batch_size,
sequence_length: sequence_length
)
```
<!-- livebook:{"output":true} -->
```
#Stream<[
enum: #Stream<[
enum: #Function<73.57817549/2 in Stream.zip_with/2>,
funs: [#Function<3.57817549/1 in Stream.chunk_while/4>]
]>,
funs: [#Function<48.57817549/1 in Stream.map/2>]
]>
```
You can see what a single batch looks like by taking one from the stream:
```elixir
Enum.take(train_data, 1)
```
<!-- livebook:{"output":true} -->
```
[
{%{
"attention_mask" => #Nx.Tensor<
u32[32][64]
EXLA.Backend<host:0, 0.3007920939.2860646413.150020>
[
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...],
...
]
>,
"input_ids" => #Nx.Tensor<
u32[32][64]
EXLA.Backend<host:0, 0.3007920939.2860646413.150019>
[
[101, 173, 1197, 119, 2284, 2953, 3272, 1917, 178, 1440, 1111, 1107, 170, 1704, 22351, 119, 1119, 112, 188, 3505, 1105, 3123, 1106, 2037, 1106, 1443, 1217, 10063, 4404, 132, 1119, 112, 188, 1579, 1113, 1159, 1107, 3195, 1117, 4420, 132, 1119, 112, 188, 6559, 1114, ...],
...
]
>,
"token_type_ids" => #Nx.Tensor<
u32[32][64]
EXLA.Backend<host:0, 0.3007920939.2860646413.150021>
[
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...],
...
]
>
},
#Nx.Tensor<
s64[32]
EXLA.Backend<host:0, 0.3007920939.2860646413.150055>
[5, 2, 4, 4, 1, 5, 5, 1, 2, 3, 1, 1, 4, 2, 5, 5, 5, 5, 5, 5, 4, 3, 2, 5, 1, 1, 1, 2, 2, 4, 2, 2]
>}
]
```
The dataset is rather large for CPU training, so we'll train on just a small subset (250 training batches and 50 testing batches):
```elixir
train_data = Enum.take(train_data, 250)
test_data = Enum.take(test_data, 50)
:ok
```
<!-- livebook:{"output":true} -->
```
:ok
```
## Train the model
Now we can go about training the model! First, we need to extract the Axon model and parameters from the Bumblebee model map:
```elixir
%{model: model, params: params} = model
model
```
<!-- livebook:{"output":true} -->
```
#Axon<
inputs: %{"attention_head_mask" => {12, 12}, "attention_mask" => {nil, nil}, "input_ids" => {nil, nil}, "position_ids" => {nil, nil}, "token_type_ids" => {nil, nil}}
outputs: "container_37"
nodes: 790
>
```
The Axon model actually outputs a map with `:logits`, `:hidden_states`, and `:attentions` keys. You can see this by using `Axon.get_output_shape/2` with an input. This function symbolically executes the graph and returns the resulting shapes:
```elixir
[{input, _}] = Enum.take(train_data, 1)
Axon.get_output_shape(model, input)
```
<!-- livebook:{"output":true} -->
```
%{attentions: #Axon.None<...>, hidden_states: #Axon.None<...>, logits: {32, 5}}
```
For training, we only care about the `:logits` key, so we'll extract that by attaching an `Axon.nx/2` layer to the model:
```elixir
logits_model = Axon.nx(model, & &1.logits)
```
<!-- livebook:{"output":true} -->
```
#Axon<
inputs: %{"attention_head_mask" => {12, 12}, "attention_mask" => {nil, nil}, "input_ids" => {nil, nil}, "position_ids" => {nil, nil}, "token_type_ids" => {nil, nil}}
outputs: "nx_0"
nodes: 791
>
```
Now we can declare our training loop. You construct Axon training loops with the `Axon.Loop.trainer/3` factory function, passing a model, loss function, and optimizer. We'll also adjust the log settings so metrics are written to standard output on every iteration:
```elixir
loss =
&Axon.Losses.categorical_cross_entropy(&1, &2,
reduction: :mean,
from_logits: true,
sparse: true
)
optimizer = Axon.Optimizers.adam(5.0e-5)
loop = Axon.Loop.trainer(logits_model, loss, optimizer, log: 1)
```
<!-- livebook:{"output":true} -->
```
#Axon.Loop<
metrics: %{
"loss" => {#Function<11.72223263/3 in Axon.Metrics.running_average/1>,
#Function<41.3316493/2 in :erl_eval.expr/6>}
},
handlers: %{
completed: [],
epoch_completed: [
{#Function<27.24288319/1 in Axon.Loop.log/3>,
#Function<6.24288319/2 in Axon.Loop.build_filter_fn/1>}
],
epoch_halted: [],
epoch_started: [],
halted: [],
iteration_completed: [
{#Function<27.24288319/1 in Axon.Loop.log/3>,
#Function<64.24288319/2 in Axon.Loop.build_filter_fn/1>}
],
iteration_started: [],
started: []
},
...
>
```
The call to `trainer` just returns a data structure. In Axon, you manipulate this data structure to control different parts of the loop. For example, you can attach metrics:
```elixir
accuracy = &Axon.Metrics.accuracy(&1, &2, from_logits: true, sparse: true)
loop = Axon.Loop.metric(loop, accuracy, "accuracy")
```
<!-- livebook:{"output":true} -->
```
#Axon.Loop<
metrics: %{
"accuracy" => {#Function<11.72223263/3 in Axon.Metrics.running_average/1>,
#Function<41.3316493/2 in :erl_eval.expr/6>},
"loss" => {#Function<11.72223263/3 in Axon.Metrics.running_average/1>,
#Function<41.3316493/2 in :erl_eval.expr/6>}
},
handlers: %{
completed: [],
epoch_completed: [
{#Function<27.24288319/1 in Axon.Loop.log/3>,
#Function<6.24288319/2 in Axon.Loop.build_filter_fn/1>}
],
epoch_halted: [],
epoch_started: [],
halted: [],
iteration_completed: [
{#Function<27.24288319/1 in Axon.Loop.log/3>,
#Function<64.24288319/2 in Axon.Loop.build_filter_fn/1>}
],
iteration_started: [],
started: []
},
...
>
```
And you can attach event handlers, for example to serialize the loop state at regular intervals so you don't lose your progress:
```elixir
loop = Axon.Loop.checkpoint(loop, event: :epoch_completed)
```
<!-- livebook:{"output":true} -->
```
#Axon.Loop<
metrics: %{
"accuracy" => {#Function<11.72223263/3 in Axon.Metrics.running_average/1>,
#Function<41.3316493/2 in :erl_eval.expr/6>},
"loss" => {#Function<11.72223263/3 in Axon.Metrics.running_average/1>,
#Function<41.3316493/2 in :erl_eval.expr/6>}
},
handlers: %{
completed: [],
epoch_completed: [
{#Function<17.24288319/1 in Axon.Loop.checkpoint/2>,
#Function<6.24288319/2 in Axon.Loop.build_filter_fn/1>},
{#Function<27.24288319/1 in Axon.Loop.log/3>,
#Function<6.24288319/2 in Axon.Loop.build_filter_fn/1>}
],
epoch_halted: [],
epoch_started: [],
halted: [],
iteration_completed: [
{#Function<27.24288319/1 in Axon.Loop.log/3>,
#Function<64.24288319/2 in Axon.Loop.build_filter_fn/1>}
],
iteration_started: [],
started: []
},
...
>
```
To run the loop, you just call `Axon.Loop.run/4`. It takes a loop, input data, initial state (in this case, the initial parameters), and options. You can think of `Axon.Loop.run/4` as analogous to `Enum.reduce/3`: the data, accumulator, and reducing function of `Enum.reduce/3` map to the input data, initial state, and loop data structure of `Axon.Loop.run/4`.
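To make the analogy concrete, here's a minimal sketch (the `Axon.Loop.run/4` call is shown as a comment, since we run the real loop below):

```elixir
# Enum.reduce/3 folds a function over data, threading an accumulator through:
sum = Enum.reduce([1, 2, 3], 0, fn x, acc -> acc + x end)
# => 6

# Axon.Loop.run/4 is analogous: the loop struct plays the role of the
# reducing function, the input data is the enumerable, and the initial
# state (here, our params) is the accumulator:
#
#     Axon.Loop.run(loop, train_data, params, epochs: 3)
```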
You'll commonly see loops written out in long chains using Elixir's `|>` operator, like this:
```elixir
trained_model_state =
logits_model
|> Axon.Loop.trainer(loss, optimizer, log: 1)
|> Axon.Loop.metric(accuracy, "accuracy")
|> Axon.Loop.checkpoint(event: :epoch_completed)
|> Axon.Loop.run(train_data, params, epochs: 3, compiler: EXLA, strict?: false)
:ok
```
<!-- livebook:{"output":true} -->
```
00:04:39.518 [debug] Forwarding options: [compiler: EXLA] to JIT compiler
Epoch: 0, Batch: 249, accuracy: 0.3548750 loss: 1.2023211
Epoch: 1, Batch: 77, accuracy: 0.4431090 loss: 1.1546737
```
## Evaluate the model
The training loop returns the final model state after training on your dataset for the given number of epochs. Axon uses the same `Axon.Loop` API for evaluation loops. You can create one with the `Axon.Loop.evaluator/1` factory function, instrument it with metrics, and run it on the test data with your trained model state:
```elixir
logits_model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(accuracy, "accuracy")
|> Axon.Loop.run(test_data, trained_model_state, compiler: EXLA)
```
<!-- livebook:{"output":true} -->
```
Batch: 49, accuracy: 0.3675000
```
<!-- livebook:{"output":true} -->
```
%{
0 => %{
"accuracy" => #Nx.Tensor<
f32
EXLA.Backend<host:0, 0.446408219.3911319572.169449>
0.36750003695487976
>
}
}
```
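As a final check, you can run an ad-hoc prediction with the fine-tuned parameters. Here's a minimal sketch using `Axon.predict/4` and the tokenizer from earlier (the review text is made up; the 0-4 labels map back to 1-5 stars):

```elixir
review = "The food was fantastic and the staff were incredibly friendly!"

inputs = Bumblebee.apply_tokenizer(tokenizer, [review], length: sequence_length)
logits = Axon.predict(logits_model, trained_model_state, inputs, compiler: EXLA)

# Take the argmax over the 5 classes and shift the 0-4 label back to 1-5 stars.
predicted_stars =
  logits
  |> Nx.argmax(axis: -1)
  |> Nx.to_flat_list()
  |> hd()
  |> Kernel.+(1)
```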