# Introduction to Nx
```elixir
Mix.install([
  {:nx, "~> 0.5"}
])
```
## Numerical Elixir
Elixir's primary numerical datatypes and structures are not optimized
for numerical programming. Nx is a library built to bridge that gap.
[Elixir Nx](https://github.com/elixir-nx/nx) is a numerical computing library
that integrates smoothly with typed, multidimensional data (called tensors)
implemented on other platforms. This support extends to the compilers and
libraries that support those tensors. Nx has three primary capabilities:
* Tensors hold typed data in multiple, named dimensions.
* Numerical definitions, known as `defn`, support custom code with
tensor-aware operators and functions.
* [Automatic differentiation](https://arxiv.org/abs/1502.05767), also known as
autograd or autodiff, supports common computational scenarios
such as machine learning, simulations, curve fitting, and probabilistic models.
Here's more about each of those capabilities. Nx tensors can hold
unsigned integers (u8, u16, u32, u64),
signed integers (s8, s16, s32, s64),
floats (f32, f64), brain floats (bf16), and complex (c64, c128).
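As a small illustration of the type list above, the sketch below builds the same data with a few different element types and asks Nx for the resulting type. It assumes the setup cell at the top of this notebook has already installed Nx; the variable name `types` is only for illustration.

```elixir
# Assumes the notebook's setup cell has already installed Nx.
# Build the same values with several element types and read the type back.
types =
  for type <- [:u8, :s16, :f32, :f64] do
    [1, 2, 3] |> Nx.tensor(type: type) |> Nx.type()
  end
```

Nx reports each type as a tuple of a kind and a bit width, such as `{:u, 8}` for `u8`.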
Tensors support backends implemented outside of Elixir, including Google's
Accelerated Linear Algebra (XLA) and LibTorch.
Numerical definitions have compiler support for just-in-time compilation
that targets specialized processors, including GPUs and TPUs, to speed up
numeric computation.
To know Nx, we'll get to know tensors first. This rapid overview will touch
on the major features. Then, future notebooks will take a deep dive into working
with tensors in detail, autograd, and backends. After that, we'll dive into specific
problem spaces like machine learning with the Axon library.
## Nx and tensors
Systems of equations are a central theme in numerical computing.
These equations are often expressed and solved with multidimensional
arrays. For example, this is a two dimensional array:
$$
\begin{bmatrix}
1 & 2 \\\\
3 & 4
\end{bmatrix}
$$
Elixir programmers typically express a similar data structure using
a list of lists, like this:
```elixir
[
  [1, 2],
  [3, 4]
]
```
This data structure works fine within many functional programming
algorithms, but breaks down with deep nesting and random access.
On top of that, Elixir numeric types lack optimization for many numerical
applications. They work fine when programs
need hundreds or even thousands of calculations. They tend to break
down with traditional STEM applications when a typical problem
needs millions of calculations.
In Nx, we express multi-dimensional data using typed tensors. Simply put,
a tensor is a multi-dimensional array with a predetermined shape and
type. To interact with them, Nx relies on tensor-aware operators rather
than `Enum.map/2` and `Enum.reduce/3`.
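To make the contrast concrete, here is a minimal sketch that doubles every value twice: once by walking a list of lists with `Enum.map/2`, and once with a single tensor-aware call. It assumes the setup cell has installed Nx; the variable names are illustrative only.

```elixir
rows = [[1, 2], [3, 4]]

# Plain Elixir: traverse the outer list, then each inner list.
doubled_rows = Enum.map(rows, fn row -> Enum.map(row, &(&1 * 2)) end)

# Nx: one call that applies to every element of the tensor.
doubled_tensor = rows |> Nx.tensor() |> Nx.multiply(2)

# Convert back to nested lists to compare the two results.
{doubled_rows, Nx.to_list(doubled_tensor)}
```

Both approaches produce `[[2, 4], [6, 8]]`, but the Nx version does not change as the nesting gets deeper.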
In this section, we'll look at some of the various tools for
creating and interacting with tensors. The IEx helpers will assist our
exploration of the core tensor concepts.
```elixir
import IEx.Helpers
```
Now, everything is set up, so we're ready to create some tensors.
<!-- livebook:{"break_markdown":true} -->
### Creating tensors
Start out by getting a feel for Nx through its documentation.
Do so through the IEx helpers, like this:
<!-- livebook:{"disable_formatting":true} -->
```elixir
h Nx
```
Immediately, you can see that tensors are at the center of the
API. The main API for creating tensors is `Nx.tensor/2`:
<!-- livebook:{"disable_formatting":true} -->
```elixir
h Nx.tensor
```
We use it to create tensors from raw Elixir lists of numbers, like this:
```elixir
tensor =
  1..4
  |> Enum.chunk_every(2)
  |> Nx.tensor(names: [:y, :x])
```
The result shows all of the major fields that make up a tensor:
* The data, presented as the list of lists `[[1, 2], [3, 4]]`.
* The type of the tensor, a signed integer 64 bits long, with the type `s64`.
* The shape of the tensor, going left to right, with the outside dimensions listed first.
* The names of each dimension.
We can easily convert it to a binary:
```elixir
binary = Nx.to_binary(tensor)
```
A tensor of type s64 uses eight bytes for each integer. The binary
shows the individual bytes that make up the tensor, so you can see
the integers `1..4` interspersed among the zeros that make
up the tensor. If all of our data only uses positive numbers from
`0..255`, we could save space with a different type:
```elixir
Nx.tensor([[1, 2], [3, 4]], type: :u8) |> Nx.to_binary()
```
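To put numbers on the space savings, the sketch below measures the binary size of the same four values stored as `s64` and as `u8`. The types are set explicitly so the comparison does not depend on the default integer type of the installed Nx version; the variable names are illustrative.

```elixir
data = [[1, 2], [3, 4]]

# Four values at eight bytes each.
wide = data |> Nx.tensor(type: :s64) |> Nx.to_binary() |> byte_size()

# Four values at one byte each.
narrow = data |> Nx.tensor(type: :u8) |> Nx.to_binary() |> byte_size()

{wide, narrow}
# → {32, 4}
```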
If you already have a binary, you can directly convert it to a tensor
by passing the binary and the type:
```elixir
Nx.from_binary(<<0, 1, 2>>, :u8)
```
This function comes in handy when working with published datasets
because their raw contents often need to be decoded before use. Elixir
binaries make quick work of dealing with numerical data structured for
platforms other than Elixir.
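As a rough sketch of that workflow, the example below decodes a small made-up payload of 16-bit unsigned integers and gives it a shape. The payload is hypothetical and stands in for bytes read from a dataset file; it is built in native byte order because that is what `Nx.from_binary/2` expects.

```elixir
# Hypothetical payload: four unsigned 16-bit integers in native byte order.
payload =
  <<1::native-size(16), 2::native-size(16), 3::native-size(16), 4::native-size(16)>>

# Decode the bytes into a flat u16 tensor, then give it a {2, 2} shape.
decoded =
  payload
  |> Nx.from_binary(:u16)
  |> Nx.reshape({2, 2})
```

If a dataset uses a different byte order than the machine running the notebook, the bytes would need to be reordered with Elixir's binary pattern matching before this step.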
We can get any cell of the tensor:
```elixir
tensor[0][1]
```
Now, try getting the first row of the tensor:
```elixir
# ...your code here...
```
We can also get a whole dimension:
```elixir
tensor[x: 1]
```
or a range:
```elixir
tensor[y: 0..1]
```
Now,
* create your own `{3, 3}` tensor with named dimensions
* return a `{2, 2}` tensor containing the first two columns
of the first two rows
We can get information about `tensor` with
the IEx helper `i`, like this:
<!-- livebook:{"disable_formatting":true} -->
```elixir
i tensor
```
The tensor is a struct that supports the usual `Inspect` protocol.
The struct has keys, but we typically treat the `Nx.Tensor`
as an _opaque data type_ (meaning we typically access the contents and
shape of a tensor using the tensor's API instead of the struct).
Primarily, a tensor is a struct, and the
functions to access it go through a specific backend. We'll get to
the backend details in a moment. For now, use the IEx `h` helper
to get more documentation about tensors. We could also open a Code
cell, type `Nx.tensor`, and hover the cursor over the word `tensor`
to see the help about that function.
We can get the shape of the tensor with `Nx.shape/1`:
```elixir
Nx.shape(tensor)
```
We can also create a new tensor with a new shape using `Nx.reshape/2`:
```elixir
Nx.reshape(tensor, {1, 4}, names: [:batches, :values])
```
This operation reuses all of the tensor data and simply
changes the metadata, so it has no notable cost.
The new tensor has the same type, but a new shape.
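A quick way to convince yourself that reshaping only changes metadata is to compare the flattened contents before and after. This is a minimal sketch with illustrative variable names and an explicit type:

```elixir
original = Nx.tensor([[1, 2], [3, 4]], type: :s32)

# Flatten the {2, 2} tensor into a single axis of four values.
reshaped = Nx.reshape(original, {4})

# The shape changes; the type and the flattened values do not.
same_values? = Nx.to_flat_list(original) == Nx.to_flat_list(reshaped)

{Nx.shape(reshaped), Nx.type(reshaped), same_values?}
# → {{4}, {:s, 32}, true}
```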
Now, reshape the tensor to contain three dimensions with
one batch, one row, and four columns.
```elixir
# ...your code here...
```
We can create a tensor with named dimensions, a type, a shape,
and our target data. A dimension is called an _axis_, and axes
can have names. We can specify the tensor type and dimension names
with options, like this:
```elixir
Nx.tensor([[1, 2, 3]], names: [:rows, :cols], type: :u8)
```
We created a tensor of the shape `{1, 3}`, with the type `u8`,
the values `[1, 2, 3]`, and two axes named `rows` and `cols`.
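Each of those properties can be read back through the tensor API rather than the struct fields. The sketch below uses `Nx.shape/1`, `Nx.type/1`, `Nx.names/1`, and `Nx.axis_size/2`; the variable name `named` is illustrative.

```elixir
named = Nx.tensor([[1, 2, 3]], names: [:rows, :cols], type: :u8)

# Read the metadata through functions instead of struct keys.
{Nx.shape(named), Nx.type(named), Nx.names(named), Nx.axis_size(named, :cols)}
# → {{1, 3}, {:u, 8}, [:rows, :cols], 3}
```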
Now we know how to create tensors, so it's time to do something with them.
## Tensor-aware functions
In the last section, we created a `s64[2][2]` tensor. In this section,
we'll use Nx functions to work with it. Here's the value of `tensor`:
```elixir
tensor
```
We can use `IEx.Helpers.exports/1` or code completion to find
some functions in the `Nx` module that operate on tensors:
<!-- livebook:{"disable_formatting":true} -->
```elixir
exports Nx
```
You might recognize that many of those functions have names that
suggest that they would work on primitive values, called scalars.
Indeed, a tensor can be a scalar:
```elixir
pi = Nx.tensor(3.1415, type: :f32)
```
Take the cosine:
```elixir
Nx.cos(pi)
```
That function took the cosine of `pi`. We can also call it
on a whole tensor, like this:
```elixir
Nx.cos(tensor)
```
We can also call a function that aggregates the contents
of a tensor. For example, to get a sum of the numbers
in `tensor`, we can do this:
```elixir
Nx.sum(tensor)
```
That's `1 + 2 + 3 + 4`, and Nx went to multiple dimensions to get that sum.
To get the sum of values along the `x` axis instead, we'd do this:
```elixir
Nx.sum(tensor, axes: [:x])
```
Nx sums the values across the `x` dimension: `1 + 2` in the first row
and `3 + 4` in the second row.
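Axes can be selected by name or by position. As a small sketch, the example below sums the same `{2, 2}` data along the `y` axis both ways; the variable names are illustrative.

```elixir
grid = Nx.tensor([[1, 2], [3, 4]], names: [:y, :x])

# Collapse the y axis by name: 1 + 3 and 2 + 4.
by_name = Nx.sum(grid, axes: [:y])

# The y axis is also axis 0, so the positional form gives the same result.
by_index = Nx.sum(grid, axes: [0])

{Nx.to_list(by_name), Nx.to_list(by_index)}
# → {[4, 6], [4, 6]}
```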
Now,
* create a `{2, 2, 2}` tensor
* with the values `1..8`
* with dimension names `[:z, :y, :x]`
* calculate the sums along the `y` axis
```elixir
# ...your code here...
```
Sometimes, we need to combine two tensors together with an
operator. Let's say we wanted to subtract one tensor from
another. Mathematically, the expression looks like this:
$$
\begin{bmatrix}
5 & 6 \\\\
7 & 8
\end{bmatrix} -
\begin{bmatrix}
1 & 2 \\\\
3 & 4
\end{bmatrix} =
\begin{bmatrix}
4 & 4 \\\\
4 & 4
\end{bmatrix}
$$
To solve this problem, subtract each right-hand integer from the
corresponding left-hand integer. Unfortunately, we cannot
use Elixir's built-in subtraction operator as it is not tensor-aware.
Luckily, we can use the `Nx.subtract/2` function to solve the
problem:
```elixir
tensor2 = Nx.tensor([[5, 6], [7, 8]])
Nx.subtract(tensor2, tensor)
```
We get a `{2, 2}` shaped tensor full of fours, exactly as we expected.
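To see why the built-in operator is not an option outside of `defn`, the sketch below tries `Kernel.-/2` on two tensors and rescues the resulting error, then performs the same subtraction with `Nx.subtract/2`. The variable names and the `:not_tensor_aware` atom are illustrative.

```elixir
a = Nx.tensor([[5, 6], [7, 8]])
b = Nx.tensor([[1, 2], [3, 4]])

kernel_result =
  try do
    # Kernel.-/2 only understands plain numbers, so this raises.
    Kernel.-(a, b)
  rescue
    ArithmeticError -> :not_tensor_aware
  end

{kernel_result, Nx.to_list(Nx.subtract(a, b))}
# → {:not_tensor_aware, [[4, 4], [4, 4]]}
```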
When calling `Nx.subtract/2`, both operands had the same shape.
Sometimes, you might want to apply an operator to tensors whose shapes
don't match. To solve this problem, Nx takes advantage of
a concept called _broadcasting_.
## Broadcasts
Often, the dimensions of tensors in an operator don't match.
For example, you might want to subtract a `1` from every
element of a `{2, 2}` tensor, like this:
$$
\begin{bmatrix}
1 & 2 \\\\
3 & 4
\end{bmatrix} - 1 =
\begin{bmatrix}
0 & 1 \\\\
2 & 3
\end{bmatrix}
$$
Mathematically, it's the same as this:
$$
\begin{bmatrix}
1 & 2 \\\\
3 & 4
\end{bmatrix} -
\begin{bmatrix}
1 & 1 \\\\
1 & 1
\end{bmatrix} =
\begin{bmatrix}
0 & 1 \\\\
2 & 3
\end{bmatrix}
$$
That means we need a way to convert `1` to a `{2, 2}` tensor.
`Nx.broadcast/2` solves that problem. This function takes
a tensor or a scalar and a shape.
```elixir
Nx.broadcast(1, {2, 2})
```
This broadcast takes the scalar `1` and translates it
to a compatible shape by copying it. Sometimes, it's easier
to provide a tensor as the second argument, and let `broadcast/2`
extract its shape:
```elixir
Nx.broadcast(1, tensor)
```
The code broadcasts `1` to the shape of `tensor`. In many operators
and functions, the broadcast happens automatically:
```elixir
Nx.subtract(tensor, 1)
```
This result is possible because Nx broadcasts _both tensors_
in `subtract/2` to compatible shapes. That means you can provide
scalar values as either argument:
```elixir
Nx.subtract(10, tensor)
```
Or subtract a row or column. Mathematically, it would look like this:
$$
\begin{bmatrix}
1 & 2 \\\\
3 & 4
\end{bmatrix} -
\begin{bmatrix}
1 & 2
\end{bmatrix} =
\begin{bmatrix}
0 & 0 \\\\
2 & 2
\end{bmatrix}
$$
which is the same as this:
$$
\begin{bmatrix}
1 & 2 \\\\
3 & 4
\end{bmatrix} -
\begin{bmatrix}
1 & 2 \\\\
1 & 2
\end{bmatrix} =
\begin{bmatrix}
0 & 0 \\\\
2 & 2
\end{bmatrix}
$$
This rewrite happens in Nx too, also through a broadcast. We want to
broadcast the tensor `[1, 2]` to match the `{2, 2}` shape, like this:
```elixir
Nx.broadcast(Nx.tensor([1, 2]), {2, 2})
```
The `subtract` function in `Nx` takes care of that broadcast
implicitly, as before:
```elixir
Nx.subtract(tensor, Nx.tensor([1, 2]))
```
The broadcast worked as advertised, copying the `[1, 2]` row
enough times to fill a `{2, 2}` tensor. A tensor with a
dimension of `1` will broadcast to fill the tensor:
```elixir
[[1], [2]] |> Nx.tensor() |> Nx.broadcast({1, 2, 2})
```
```elixir
[[[1, 2, 3]]]
|> Nx.tensor()
|> Nx.broadcast({4, 2, 3})
```
Both of these examples copy parts of the tensor enough
times to fill out the broadcast shape. You can check out the
Nx broadcasting documentation for more details:
<!-- livebook:{"disable_formatting":true} -->
```elixir
h Nx.broadcast
```
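The row case has a column counterpart, and broadcasting has limits. The sketch below subtracts a `{2, 1}` column from a `{2, 2}` tensor, then shows that a `{3}` tensor cannot be broadcast to `{2, 2}`. The variable names and the `:incompatible_shapes` atom are illustrative.

```elixir
matrix = Nx.tensor([[1, 2], [3, 4]])
column = Nx.tensor([[1], [2]])

# The {2, 1} column is copied across the second axis before subtracting.
column_result = Nx.subtract(matrix, column)

# A {3} tensor cannot be stretched into {2, 2}, so Nx raises.
mismatch =
  try do
    Nx.broadcast(Nx.tensor([1, 2, 3]), {2, 2})
  rescue
    ArgumentError -> :incompatible_shapes
  end

{Nx.to_list(column_result), mismatch}
# → {[[0, 1], [1, 2]], :incompatible_shapes}
```

Roughly speaking, an axis can be stretched only when its size is `1` or it is missing altogether.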
Much of the time, you won't have to broadcast yourself. Many of
the functions and operators Nx supports will do so automatically.
We can use tensor-aware operators via various `Nx` functions and
many of them implicitly broadcast tensors.
Throughout this section, we have been invoking `Nx.subtract/2` and
our code would be more expressive if we could use its equivalent
mathematical operator. Fortunately, Nx provides a way. Next, we'll
dive into numerical definitions using `defn`.
## Numerical definitions (defn)
The `defn` macro simplifies the expression of mathematical formulas
containing tensors. Numerical definitions have two primary benefits
over classic Elixir functions.
* They are _tensor-aware_. Nx replaces operators like `Kernel.-/2`
with the `Defn` counterparts — which in turn use `Nx` functions
optimized for tensors — so the formulas we express can use
tensors out of the box.
* `defn` definitions allow for building a computation graph of all the
individual operations and using a just-in-time (JIT) compiler to emit
highly specialized native code for the desired computation unit.
We don't have to do anything special to get
tensor awareness beyond importing `Nx.Defn` and writing
our code within a `defn` block.
To use Nx in a Mix project or a notebook, we need to include
the `:nx` dependency and import the `Nx.Defn` module. The
dependency is already included, so import it in a Code cell,
like this:
```elixir
import Nx.Defn
```
Just as the Elixir language supports `def`, `defmacro`, and `defp`,
Nx supports `defn`. There are a few restrictions. It accepts only
numerical values, in the form of primitives or tensors, as arguments
or return values, and supports only a subset of the language.
The subset of Elixir allowed within `defn` is quite broad, though. We can
use macros, pipes, and even conditionals, so we're not giving up
much when we're declaring mathematical functions.
Additionally, despite these small concessions, `defn` provides huge benefits.
Code in a `defn` block uses tensor-aware operators and types, so the math
beneath your functions has a better chance to shine through. Numerical
definitions can also run on accelerated numerical processors like GPUs and
TPUs. Here's an example numerical definition:
```elixir
defmodule TensorMath do
  import Nx.Defn

  defn subtract(a, b) do
    a - b
  end
end
```
This module has a numerical definition that will be compiled.
If we wanted to specify a compiler for this module, we could add
a module attribute before the `defn` clause. One such compiler
is [the EXLA compiler](https://github.com/elixir-nx/nx/tree/main/exla).
You'd add the `mix` dependency for EXLA and do this:
<!-- livebook:{"force_markdown":true} -->
```elixir
@defn_compiler EXLA
defn subtract(a, b) do
  a - b
end
```
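Calling a numerical definition looks like calling any other function, and plain numbers can be mixed with tensors. The sketch below uses a hypothetical module, `DefnSketch`, which is not part of this notebook's exercises, and runs with the default compiler, so it needs no extra dependencies.

```elixir
defmodule DefnSketch do
  import Nx.Defn

  # Inside defn, * and + are the tensor-aware versions of the operators.
  defn scale_and_shift(x, scale, shift) do
    x * scale + shift
  end
end

# The tensor is scaled element by element; the numbers are broadcast.
result = DefnSketch.scale_and_shift(Nx.tensor([1.0, 2.0, 3.0]), 2.0, 1.0)

Nx.to_list(result)
# → [3.0, 5.0, 7.0]
```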
Now, it's your turn. Add a `defn` to `TensorMath`
that accepts two tensors representing the lengths of sides of a
right triangle and uses the Pythagorean theorem to return the
[length of the hypotenuse](https://www.mathsisfun.com/pythagoras.html).
Add your function directly to the previous Code cell.
The last major feature we'll cover is called auto-differentiation, or autograd.
## Automatic differentiation (autograd)
An important mathematical property for a function is the
rate of change, or the gradient. These gradients are critical
for solving systems of equations and building probabilistic
models. In calculus, we find gradients by taking
derivatives. Nx can compute these derivatives
automatically through a feature called automatic differentiation,
or autograd.
Here's how it works.
<!-- livebook:{"disable_formatting":true} -->
```elixir
h Nx.Defn.grad
```
We'll build a module with a few functions,
and then create another function to create the gradients of those
functions. The function `grad/1` takes a function, and returns
a function returning the gradient. We have two functions: `poly/1`
is a simple numerical definition, and `poly_slope_at/1` returns
its gradient:
$$
poly: f(x) = 3x^2 + 2x + 1 \\\\
$$
$$
polySlopeAt: g(x) = 6x + 2
$$
Here's the Elixir equivalent of those functions:
```elixir
defmodule Funs do
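  import Nx.Defn

  # f(x) = 3x^2 + 2x + 1
  defn poly(x) do
    3 * Nx.pow(x, 2) + 2 * x + 1
  end

  # The derivative of poly/1 with respect to x, computed by autograd.
  defn poly_slope_at(x) do
    grad(&poly/1).(x)
  end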
end
```
Notice the second `defn`. It uses `grad/1` to take the
derivative of `poly/1` using autograd. Autograd uses the intermediate
`defn` AST and mathematical composition to compute the derivative. You can
see it at work here:
```elixir
Funs.poly_slope_at(2)
```
Nice. If you plug the number 2 into the function $6x + 2$
you get 14! Said another way, if you look at the graph at
exactly 2, the rate of increase is 14 units of `poly(x)`
for every unit of `x`, precisely at `x = 2`.
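One way to sanity-check that slope without autograd is a central finite difference, which estimates the derivative from two nearby points. This is a plain Elixir sketch, not how Nx computes gradients; the step size and variable names are assumptions chosen for illustration.

```elixir
# The same polynomial, written as an ordinary anonymous function.
poly = fn x -> 3 * x * x + 2 * x + 1 end

# A small step on either side of x = 2.
step = 1.0e-4

# Slope of the line through the two neighboring points.
estimate = (poly.(2 + step) - poly.(2 - step)) / (2 * step)
```

The estimate lands very close to 14, matching the value from $6x + 2$.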
Nx also has helpers to get gradients corresponding to a number of inputs.
These come into play when solving systems of equations.
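As a sketch of the multi-input case, `grad/2` accepts a tuple of variables and returns a tuple of partial derivatives. The module name `MultiGradSketch` and the function it differentiates are hypothetical, chosen so the expected derivatives are easy to work out by hand.

```elixir
defmodule MultiGradSketch do
  import Nx.Defn

  # f(a, b) = a * b + b^2, so df/da = b and df/db = a + 2b.
  defn partials(a, b) do
    grad({a, b}, fn {a, b} -> a * b + b * b end)
  end
end

# At a = 3 and b = 2 we expect df/da = 2 and df/db = 7.
{da, db} = MultiGradSketch.partials(3.0, 2.0)

{Nx.to_number(da), Nx.to_number(db)}
```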
Now, you try. Write a function that computes the gradient of a `sin` wave.
```elixir
# your code here
```