Para entrenar modelos a escala necesitas dominar unos pocos conceptos que definen velocidad, memoria y estabilidad. Esta guia recorre precision numerica, paralelismo y estrategias de cuantizacion que realmente se usan en produccion.

Precision numerica

El formato numerico que elijas (FP32, FP16, BF16, FP8, INT8, etc.) impacta directamente en throughput, uso de memoria y estabilidad del entrenamiento. Si queres escalar, no es opcional entenderlo.

Por que importa la precision

Una computadora es finita. Los numeros reales no. Si queres representar algo como $\pi$, que tiene infinitos decimales, necesitas aproximarlo. El punto flotante es exactamente ese esquema de aproximacion.

En floating point, un valor real se representa con tres partes: signo, mantisa (resolucion fina) y exponente (rango dinamico). En IEEE 754 la base es 2, asi que el modelo habitual es:

$$ \text{value} \approx \text{sign}\times \text{mantissa}\times \text{base}^{\text{exponent}}. $$

Mas bits en mantisa dan mas resolucion. Mas bits en exponente dan mayor rango.

32-bit floating-point layout (FP32)
Figura 1. Layout FP32 bajo IEEE 754.

Esto importa por tres motivos concretos: los formatos de pocos bits son mas rapidos y consumen menos memoria, un rango insuficiente genera overflow/underflow, y para entrenar LLMs grandes sin costos absurdos casi siempre terminas en precision mixta.

BF16 vs FP32 vs FP16
Figura 2. Formatos mas usados en deep learning: BF16, FP32, FP16.

FP32 (single precision)

Es la linea base. Usa 1 bit de signo, 8 de exponente y 23 de mantisa. El rango ronda $1.18\times10^{-38}$ a $3.4\times10^{38}$, con epsilon de maquina cercano a $1.19\times10^{-7}$. Incluso si guardas tensores en menor precision, muchas acumulaciones criticas se hacen en FP32 para mantener estabilidad.

FP16 (half)

FP16 reduce memoria y suele acelerar en hardware compatible. El costo: menos detalle y menor rango. Valores chicos pueden desaparecer y valores grandes saturarse. Por eso aparece el loss scaling: escalas temporalmente loss y gradientes, y luego deshaces ese escalado al actualizar.

BF16 (brain floating point)

BF16 es hoy el formato de batalla en TPUs y GPUs modernas (como H100). Mantiene exponente tipo FP32 (mismo rango dinamico) pero recorta mantisa. En la practica evita muchos problemas de overflow/underflow sin trucos extra, a cambio de perder algo de precision decimal.

Comparando precisiones

Para ejemplos rapidos voy a usar JAX. Es comodo para correr en GPU/TPU y facilita inspeccionar comportamiento de memoria y performance.

Primero, version, backend y dispositivos:

import jax
# JAX setup
print(f"JAX version: {jax.__version__}")
print(f"JAX backend: {jax.default_backend()}")
print(f"Available devices: {jax.devices()}")

Voy a medir tiempo y memoria con psutil, tracemalloc y time.

import jax.numpy as jnp
import psutil
import tracemalloc
import time

def get_memory_usage():
   """Return current process RSS in MB."""
   process = psutil.Process()
   return process.memory_info().rss / 1024 / 1024

def measure_memory_and_time(func):
   def wrapper(*args, **kwargs):
       tracemalloc.start()
       start_memory = get_memory_usage()
       start_time = time.time()
       result = func(*args, **kwargs)
       jax.block_until_ready(result)
       end_time = time.time()
       end_memory = get_memory_usage()
       current, peak = tracemalloc.get_traced_memory()
       tracemalloc.stop()
       return {
           'result': result,
           'execution_time': end_time - start_time,
           'memory_delta': end_memory - start_memory,
           'peak_memory': peak / 1024 / 1024,
           'current_memory': current / 1024 / 1024
       }
   return wrapper

Ahora un micro-benchmark de multiplicacion de matrices y un forward pass chico.

@measure_memory_and_time
def matrix_multiplication_test(dtype, shape):
   """Matrix multiply at a given dtype."""
   key = jax.random.PRNGKey(42)
   key1, key2 = jax.random.split(key)
   def matmul_operation():
       a = jax.random.normal(key1, shape, dtype=dtype)
       b = jax.random.normal(key2, shape, dtype=dtype)
       return jnp.dot(a, b)
   return matmul_operation()

@measure_memory_and_time
def neural_network_forward_pass_test(dtype, input_size, hidden_size, output_size):
   """One forward pass of a tiny MLP."""
   key = jax.random.PRNGKey(123)
   keys = jax.random.split(key, 3)
   def nn_forward():
       # Weight init
       W1 = jax.random.normal(keys[0], (input_size, hidden_size), dtype=dtype)
       b1 = jax.random.normal(keys[1], (hidden_size,), dtype=dtype)
       W2 = jax.random.normal(keys[2], (hidden_size, output_size), dtype=dtype)
       b2 = jax.random.normal(keys[2], (output_size,), dtype=dtype)
       # Input
       x = jax.random.normal(keys[0], (input_size,), dtype=dtype)
       # Forward
       h = jnp.tanh(jnp.dot(x, W1) + b1)
       y = jnp.dot(h, W2) + b2
       return y
   return nn_forward()

Ejemplo de uso:

matrix_multiplication_test(jnp.float16, (5000, 5000))
neural_network_forward_pass_test(jnp.float16, 784, 256, 10)

Tambien incluyo FP64 para mostrar un resultado que a primera vista puede parecer contraintuitivo. En un M1 se puede ver algo asi (tiempo en segundos, memoria en MB):

Multiplicacion de matrices

PrecisionTiempoPico de memoria
FP161.0240.013
BF160.9780.008
FP320.9430.008
FP640.9280.009

Red neuronal

PrecisionTiempoPico de memoria
FP160.0070.020
BF160.0040.015
FP320.0030.015
FP640.0020.016

Si te quedas solo con esto pareceria que toda la discusion no sirve porque FP64 fue mas rapido. Esa lectura es incorrecta. En CPU (como M1), bibliotecas del vendor suelen estar muy optimizadas para FP32/FP64. FP16/BF16 pueden implicar conversiones internas que agregan overhead, especialmente en problemas chicos. En GPU la historia se invierte porque Tensor Cores estan optimizados para FP16/BF16 y menor movimiento de datos.

Si repetis estas pruebas en GPU y con matrices mas grandes, aparecen los speedups esperados. Un comparativo de TFLOPs lo muestra bien:

Throughput vs precision
Figura 3. El throughput crece a medida que baja la precision en operaciones matriciales.

Para cerrar esta parte, un ejemplo minimo de precision mixta: calculo sensible en FP32 y almacenamiento en FP16.

def mixed_precision_forward_pass(x, W, b):
   # Promote to FP32 for compute
   x_fp32 = x.astype(jnp.float32)
   W_fp32 = W.astype(jnp.float32)
   b_fp32 = b.astype(jnp.float32)
   # Forward
   y = jnp.dot(x_fp32, W_fp32) + b_fp32
   # Store in FP16 to save memory
   return y.astype(jnp.float16)

Cuantizacion

En entrenamiento solemos conservar operaciones criticas en FP32 y usar FP16/BF16 para el resto. En inferencia, el objetivo cambia: menor latencia y menor memoria sin destruir calidad. Ahi entra cuantizacion, que representa pesos y activaciones con menos bits aceptando un error numerico controlado. Normalmente se empieza con PTQ (Post-Training Quantization). Si no alcanza, se pasa a QAT (Quantization-Aware Training).

QAT versus PTQ

Figura 4. QAT a la izquierda, PTQ a la derecha.

Post-Training Quantization (PTQ)

PTQ toma un modelo ya entrenado y convierte pesos/activaciones a low-bit sin reentrenar. En muchos casos es suficiente.

Dado un valor real $x\in\mathbb{R}$, el cuantizador uniforme afino mapea $x$ a un entero $q$ de $b$ bits usando escala $s$ y zero-point $z$:

$$ q = \operatorname{clip}\Big(\operatorname{round}\big(\tfrac{x}{s}\big) + z,\ q_{\min}, q_{\max}\Big), \qquad \hat{x} = s,(q - z). $$

Para enteros signed de $b$ bits, $q_{\min}=-2^{b-1}$ y $q_{\max}=2^{b-1}-1$. Con un rango real objetivo $[\alpha,\beta]$:

$$ s=\frac{\beta-\alpha}{q_{\max}-q_{\min}}, \qquad z=\operatorname{round}\Big(\frac{-\alpha}{s}\Big)+q_{\min}. $$

Si $z=0$, la cuantizacion es simetrica. Si la distribucion esta sesgada, una forma asimetrica ($z\neq 0$) aprovecha mejor los niveles disponibles.

Referencia simple en Python:

import jax.numpy as jnp

def quantize_tensor(x, num_bits=8, signed=True, eps=1e-8):
    if signed:
        qmin = - (2 ** (num_bits - 1))
        qmax = 2 ** (num_bits - 1) - 1
    else:
        qmin = 0
        qmax = 2 ** num_bits - 1

    x_min = jnp.min(x)
    x_max = jnp.max(x)

    scale = (x_max - x_min) / (qmax - qmin + eps)
    scale = jnp.where(scale == 0, 1.0, scale)

    zero_point = jnp.round(qmin - x_min / (scale + eps))
    zero_point = jnp.clip(zero_point, qmin, qmax)

    q = jnp.clip(jnp.round(x / (scale + eps) + zero_point), qmin, qmax).astype(jnp.int32)
    x_hat = scale * (q.astype(jnp.float32) - zero_point)

    return q, x_hat, float(scale), int(zero_point)

Los pesos suelen funcionar bien con min-max per-tensor o per-channel porque su distribucion esta cerca de cero. Las activaciones son distintas: ReLU produce solo no-negativos, entonces la cuantizacion simetrica desperdicia niveles en negativos que no aparecen. Ahi conviene cuantizacion asimetrica o unsigned, junto con calibracion sobre dataset representativo. Tambien es comun percentile clipping para reducir saturacion por outliers.

Symmetric vs asymmetric quantization

Figura 5. Cuantizacion simetrica vs asimetrica.

Si PTQ pierde demasiada accuracy, toca pasar a QAT.

Quantization-Aware Training (QAT)

QAT entrena el modelo simulando durante training la misma cuantizacion que se usara en inferencia. El objetivo es:

$$ \min_{w}\ \mathbb{E}_{(x,y)\sim\mathcal{D}}\Big[L\big(f_{q(w)}(x),\,y\big)\Big]. $$

Aca $w$ son parametros reales antes de cuantizar. $f_{q(w)}$ es la red con pesos cuantizados por $q(\cdot)$. Si aplicas redondeo/clipping literal, se rompe la diferenciabilidad. El cuantizador uniforme afino tipico es:

$$ q=\operatorname{clip}\!\big(\operatorname{round}(u),\,q_{\min},\,q_{\max}\big),\quad u=\frac{w}{s}+z,\quad \tilde w=s\,(q-z), $$

y en el forward usas $\tilde w$, no $w$. Para backprop se usa el straight-through estimator (STE): se trata redondeo como identidad en backward y clipping con derivada 1 dentro del rango valido y 0 fuera.

$$ \frac{\partial \tilde w}{\partial w}\ \approx\ \begin{cases} 1 & \text{if } q_{\min} < u < q_{\max} \\ 0 & \text{otherwise} \end{cases} $$

Si tambien aprendes la escala $s$, seguis usando STE y obtenes:

$$ \frac{\partial \tilde w}{\partial s}\ \approx\ \begin{cases} -\,z\;-\;\frac{w}{s} & \text{if } q_{\min} < u < q_{\max} \\ \operatorname{clip}(u,\,q_{\min},\,q_{\max})\;-\;z\;-\;\frac{w}{s} & \text{otherwise} \end{cases} $$

En practica, este gradiente se normaliza por la cantidad de elementos que comparten escala. Ahi aparecen esquemas per-tensor y per-channel.

En una capa lineal con pesos $W\in\mathbb{R}^{C_o\times C_i}$ y activaciones $X\in\mathbb{R}^{T\times C_i}$:

  • Per-tensor usa una sola escala para todo $W$ y otra para todo $X$. Es simple y barato, pero puede desperdiciar resolucion.
  • Per-channel usa una escala distinta por canal de salida de $W$ (vector $\Delta_W\in\mathbb{R}^{1\times C_o}$).
  • Si sumas per-token para activaciones, tenes $\Delta_X\in\mathbb{R}^{T\times 1}$.

Esto se adapta mejor a la heterogeneidad real y suele reducir error en 8-bit y sobre todo en 4-bit, con el costo de almacenar mas escalas.

Per-tensor and per-channel quantization

\(X\in\mathbb{R}^{T\times C_i}\) son activaciones y \(W\in\mathbb{R}^{C_o\times C_i}\) son pesos. El diagrama superior muestra per-tensor con escalas unicas \(\Delta X[1]\) y \(\Delta W[1]\). El inferior muestra per-token + per-channel con \(\Delta X[T\times 1]\) y \(\Delta W[1\times C_o]\). Las cajas punteadas indican la region cubierta por cada escala.

El entrenamiento termina siendo riesgo empirico estandar por mini-batches, aplicando en el forward la misma cuantizacion que vas a usar despues. En cada paso calculas salida con $\tilde w=s,(q-z)$, evaluas loss, propagas gradientes con STE, normalizas gradientes de escala y actualizas tanto $w$ como escalas (si son aprendibles):

$$ \min_{w}\ \frac{1}{B}\sum_{i=1}^{B} L\big(f_{\tilde w}(x_i),\,y_i\big). $$