26. Guardar y cargar modelos en PyTorch

26.1 Introducción

Hasta aquí hemos entrenado varios modelos en PyTorch: redes densas, convolucionales y recurrentes. Pero en un proyecto real no tiene sentido entrenar siempre desde cero cada vez que queremos usar un modelo.

Por eso surge una necesidad muy práctica: guardar el modelo entrenado y luego cargarlo cuando queramos volver a usarlo.

Este tema es muy importante porque conecta el entrenamiento con el uso real de los modelos en aplicaciones, pruebas y despliegue.

26.2 Qué significa guardar un modelo

Guardar un modelo significa almacenar en disco la información necesaria para poder recuperarlo más adelante.

En términos prácticos, eso suele implicar guardar los valores aprendidos por la red, es decir, sus pesos y bias.

En algunos casos también conviene guardar más cosas, como el estado del optimizador, la época alcanzada o métricas del entrenamiento.

26.3 Por qué es necesario guardar

Entrenar un modelo puede llevar tiempo, recursos de hardware y bastante experimentación.

Si cada vez que queremos usarlo tuviéramos que volver a entrenarlo, el trabajo sería poco práctico.

Guardar el modelo permite reutilizarlo, compartirlo, evaluarlo después y seguir entrenándolo desde un punto intermedio.

26.4 Dos ideas distintas: guardar pesos o guardar un checkpoint

En PyTorch suele hablarse de dos estrategias comunes:

  • Guardar solo los pesos del modelo.
  • Guardar un checkpoint más completo.

La primera es más simple y suficiente en muchos casos. La segunda resulta útil cuando queremos reanudar el entrenamiento exactamente donde lo dejamos.

26.5 Qué es state_dict

En PyTorch, la forma más habitual de guardar un modelo es mediante su state_dict().

El state_dict es básicamente un diccionario que contiene los tensores asociados a los parámetros aprendidos y, en algunos casos, otros estados internos necesarios.

Es la forma recomendada en la mayoría de los casos para guardar el estado del modelo.

26.6 Por qué no suele recomendarse guardar todo el objeto modelo

Aunque PyTorch permite serializar objetos completos, muchas veces no es la opción más robusta.

Guardar directamente el objeto puede depender más de la estructura exacta del código, de nombres y de cómo esté definida la clase.

En cambio, guardar el state_dict suele ser más claro, más portable y más controlable.

26.7 Guardar un modelo con torch.save

La función habitual para guardar es torch.save().

torch.save(modelo.state_dict(), "modelo.pth")

Aquí estamos guardando el diccionario de parámetros del modelo en un archivo llamado modelo.pth.

26.8 Cargar un modelo con load_state_dict

Para recuperar los pesos guardados, primero debemos crear una instancia de la misma arquitectura y luego cargar el estado:

modelo = MiModelo()
modelo.load_state_dict(torch.load("modelo.pth"))

Después de eso, el modelo vuelve a tener los parámetros entrenados que habíamos guardado.

26.9 La arquitectura debe coincidir

Este punto es muy importante: si guardamos los pesos de una arquitectura, luego debemos cargar esos pesos en una instancia compatible con esa misma arquitectura.

Si la red tiene distinto tamaño o distinta estructura, la carga fallará o no tendrá sentido.

Por eso, guardar un modelo no elimina la necesidad de conservar la definición de la clase o de la arquitectura.

26.10 Qué hacer después de cargar el modelo

Después de cargar un modelo entrenado, normalmente queremos usarlo para inferencia o evaluación.

En ese caso conviene ponerlo en modo evaluación:

modelo.eval()

Esto es especialmente importante si el modelo usa capas como Dropout o Batch Normalization.

26.11 Qué es un checkpoint

Un checkpoint es un guardado más completo que no contiene solo los pesos del modelo.

Puede incluir, por ejemplo:

  • El state_dict del modelo.
  • El state_dict del optimizador.
  • La época actual.
  • La pérdida o la métrica alcanzada.

Esto resulta muy útil si queremos retomar el entrenamiento.

26.12 Guardar un checkpoint

Un checkpoint se suele guardar como un diccionario:

torch.save({ "epoch": epoca, "model_state_dict": modelo.state_dict(), "optimizer_state_dict": optimizador.state_dict() }, "checkpoint.pth")

De esta manera, no guardamos solo el modelo, sino también el contexto del entrenamiento.

26.13 Cargar un checkpoint

Para recuperar un checkpoint, primero lo cargamos y luego reasignamos sus partes:

checkpoint = torch.load("checkpoint.pth")
modelo.load_state_dict(checkpoint["model_state_dict"])
optimizador.load_state_dict(checkpoint["optimizer_state_dict"])

Así el modelo y el optimizador recuperan el estado que tenían cuando se guardó el archivo.

26.14 Cuándo guardar solo el modelo

Guardar solo el state_dict del modelo suele ser suficiente cuando el objetivo es usarlo luego para predicción o evaluación.

Por ejemplo, si ya terminaste de entrenar y solo quieres reutilizar el modelo final, normalmente eso alcanza.

Es la opción más simple y una muy buena práctica para muchos casos.

26.15 Cuándo conviene guardar un checkpoint completo

Si el entrenamiento puede ser largo o si quieres poder retomarlo más adelante, conviene guardar un checkpoint.

Esto es especialmente útil cuando entrenas durante muchas épocas o cuando trabajas con experimentos que podrían interrumpirse.

En esos contextos, guardar solo el modelo puede no ser suficiente.

26.16 Qué extensión usar

En PyTorch se suelen usar extensiones como .pth o .pt.

No hay una diferencia obligatoria en el sentido técnico más básico; ambas suelen verse en la práctica.

Lo importante no es tanto la extensión, sino tener claro qué contiene exactamente el archivo.

26.17 Guardar el mejor modelo

Una práctica muy común es guardar el modelo que obtuvo la mejor pérdida o la mejor métrica de validación.

En lugar de guardar cualquier estado final, se conserva el mejor encontrado durante el entrenamiento.

Esto ayuda a evitar quedarnos con un modelo de una época posterior que tal vez ya empezó a sobreajustar.

26.18 Ejemplo conceptual de mejor modelo

La lógica suele ser algo como esto:

si la perdida de validacion mejora:
    guardar el modelo

Ese criterio hace que el archivo guardado represente el mejor punto del entrenamiento según la métrica elegida.

26.19 Guardar no reemplaza la evaluación

Es importante no confundir el acto de guardar con el proceso de evaluar.

Guardar es una acción técnica para persistir el estado del modelo. Evaluar es una acción analítica para medir su desempeño.

Ambas cosas suelen estar relacionadas, pero no son lo mismo.

26.20 Uso típico en un flujo real

En un flujo de trabajo real, una secuencia razonable puede ser:

  1. Entrenar el modelo.
  2. Evaluarlo en validación.
  3. Guardar el mejor modelo.
  4. Cargarlo luego para inferencia o prueba final.

Esto permite separar claramente el proceso de entrenamiento del de uso posterior del modelo.

26.21 Qué ocurre con CPU y GPU

Cuando se trabaja con CPU y GPU, puede ser necesario indicar cómo cargar el archivo si cambia el dispositivo disponible.

En PyTorch esto suele resolverse con el parámetro map_location de torch.load.

En esta introducción no profundizaremos demasiado en ello, pero conviene saber que existe.

26.22 Por qué este tema es importante para un estudiante

Muchos estudiantes aprenden a entrenar modelos, pero no siempre aprenden de inmediato a guardarlos y reutilizarlos correctamente.

Sin embargo, esa habilidad es esencial para pasar de ejercicios de práctica a aplicaciones reales.

Un modelo que no puede guardarse y cargarse bien es mucho menos útil en la práctica.

26.23 Qué observaremos en la aplicación final

En la aplicación final trabajaremos con dos aplicaciones conectadas entre sí.

  • Una primera aplicación para generar un dataset pequeño, entrenar una CNN, evaluarla y guardar el modelo entrenado en disco.
  • Una segunda aplicación que recupera ese archivo guardado y permite usar el modelo sin volver a entrenarlo.
  • Una interfaz visual donde el usuario puede dibujar líneas y obtener predicciones usando el modelo recuperado.

De esa manera, el tema queda conectado con una práctica concreta y reutilizable: entrenar una vez y usar el modelo después en otra aplicación.

26.24 Por qué el ejemplo será sencillo

El objetivo del ejemplo no es complicar la arquitectura, sino mostrar con claridad el flujo de guardar un modelo y luego recuperarlo en otro programa.

Por eso usaremos un problema visual muy simple: imágenes sintéticas de 8x8 con líneas verticales, horizontales y oblicuas, junto con una CNN pequeña.

Así el foco queda puesto en el mecanismo de persistencia del modelo y en la idea práctica de separar entrenamiento e inferencia en dos aplicaciones distintas.

26.25 Código completo para ejecutar

En este tema desarrollaremos dos aplicaciones.

La primera entrena, evalúa y guarda el modelo. La segunda recupera el modelo almacenado en disco y lo usa para hacer predicciones sobre dibujos realizados por el usuario.

Aplicación 1: entrenar, evaluar y guardar el modelo

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(24)
torch.set_printoptions(precision=2, sci_mode=False)

def generar_imagen_vertical():
    col = torch.randint(1, 7, (1,)).item()

    matriz = [
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0]
    ]

    for fila in range(8):
        matriz[fila][col] = 1

    img = torch.tensor(matriz, dtype=torch.float32)
    img += 0.10 * torch.rand(8, 8)
    return img.clamp(0.0, 1.0)

def generar_imagen_horizontal():
    fila = torch.randint(1, 7, (1,)).item()

    matriz = [
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0]
    ]

    for col in range(8):
        matriz[fila][col] = 1

    img = torch.tensor(matriz, dtype=torch.float32)
    img += 0.10 * torch.rand(8, 8)
    return img.clamp(0.0, 1.0)

def generar_imagen_oblicua():
    matriz = [
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0],
        [0,0,0,0,0,0,0,0]
    ]

    tipo = torch.randint(0, 2, (1,)).item()

    if tipo == 0:
        # Diagonal principal: \
        for i in range(8):
            matriz[i][i] = 1
    else:
        # Diagonal secundaria: /
        for i in range(8):
            matriz[i][7 - i] = 1

    img = torch.tensor(matriz, dtype=torch.float32)
    img += 0.10 * torch.rand(8, 8)
    return img.clamp(0.0, 1.0)

def generar_dataset(n_por_clase):
    imagenes = []
    etiquetas = []

    for _ in range(n_por_clase):
        imagenes.append(generar_imagen_vertical())
        etiquetas.append(0)

        imagenes.append(generar_imagen_horizontal())
        etiquetas.append(1)

        imagenes.append(generar_imagen_oblicua())
        etiquetas.append(2)

    X = torch.stack(imagenes).unsqueeze(1)   # [N, 1, 8, 8]
    y = torch.tensor(etiquetas, dtype=torch.long)
    return X, y

X_train, y_train = generar_dataset(120)
X_val, y_val = generar_dataset(60)

class CNNPequena(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(16 * 2 * 2, 16)
        self.fc2 = nn.Linear(16, 3)   # 3 clases

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)

        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)

        x = torch.flatten(x, start_dim=1)

        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

def calcular_accuracy(logits, y_real):
    pred = torch.argmax(logits, dim=1)
    return (pred == y_real).float().mean().item()

modelo = CNNPequena()
criterio = nn.CrossEntropyLoss()
optimizador = optim.Adam(modelo.parameters(), lr=0.01)

for epoca in range(120):
    modelo.train()

    logits_train = modelo(X_train)
    loss_train = criterio(logits_train, y_train)

    optimizador.zero_grad()
    loss_train.backward()
    optimizador.step()

    if (epoca + 1) % 20 == 0:
        modelo.eval()
        with torch.no_grad():
            logits_train_eval = modelo(X_train)
            logits_val = modelo(X_val)

            train_loss = criterio(logits_train_eval, y_train).item()
            val_loss = criterio(logits_val, y_val).item()

            train_acc = calcular_accuracy(logits_train_eval, y_train)
            val_acc = calcular_accuracy(logits_val, y_val)

        print(f"Epoca {epoca+1:3d} | train loss={train_loss:.4f} | val loss={val_loss:.4f} | train acc={train_acc:.3f} | val acc={val_acc:.3f}")

print()
with torch.no_grad():
    logits_train = modelo(X_train)
    logits_val = modelo(X_val)

    train_loss = criterio(logits_train, y_train).item()
    val_loss = criterio(logits_val, y_val).item()

    train_acc = calcular_accuracy(logits_train, y_train)
    val_acc = calcular_accuracy(logits_val, y_val)

print("RESUMEN FINAL")
print(f"train loss={train_loss:.4f} | val loss={val_loss:.4f}")
print(f"train acc={train_acc:.3f} | val acc={val_acc:.3f}")

ejemplos_nuevos = torch.stack([
    generar_imagen_vertical(),
    generar_imagen_horizontal(),
    generar_imagen_oblicua()
]).unsqueeze(1)

nombres_clases = ["vertical", "horizontal", "oblicua"]

with torch.no_grad():
    logits = modelo(ejemplos_nuevos)
    probs = torch.softmax(logits, dim=1)
    predicciones = torch.argmax(probs, dim=1)

    print()
    print("Probabilidades para imagenes nuevas:")
    print(probs)

    print("Predicciones finales:")
    for i in range(len(predicciones)):
        clase = predicciones[i].item()
        print(f"Imagen {i+1}: {nombres_clases[clase]}")

# -------------------------------------------------
# Guardar el modelo entrenado
# -------------------------------------------------
torch.save(modelo.state_dict(), "modelo_lineas.pth")
print()
print("Modelo guardado en el archivo: modelo_lineas.pth")

Aplicación 2: recuperar el modelo guardado y usarlo

import tkinter as tk
from tkinter import messagebox
import torch
import torch.nn as nn

torch.set_printoptions(precision=2, sci_mode=False)

# ---------------------------------------------------------
# Definición del mismo modelo usado durante el entrenamiento
# ---------------------------------------------------------
class CNNPequena(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(16 * 2 * 2, 16)
        self.fc2 = nn.Linear(16, 3)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)

        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)

        x = torch.flatten(x, start_dim=1)

        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


# ---------------------------------------------------------
# Aplicación Tkinter
# ---------------------------------------------------------
class AplicacionDibujo:
    def __init__(self, ventana):
        self.ventana = ventana
        self.ventana.title("Predicción de líneas con PyTorch")

        self.filas = 8
        self.columnas = 8
        self.tamano_celda = 50

        self.nombres_clases = ["vertical", "horizontal", "oblicua"]

        # Matriz lógica donde guardamos lo dibujado
        self.matriz = [[0 for _ in range(self.columnas)] for _ in range(self.filas)]

        # Cargar modelo
        self.modelo = CNNPequena()
        try:
            self.modelo.load_state_dict(torch.load("modelo_lineas.pth", map_location="cpu"))
            self.modelo.eval()
        except FileNotFoundError:
            messagebox.showerror(
                "Error",
                "No se encontró el archivo modelo_lineas.pth\n\n"
                "Primero debes ejecutar la aplicación de entrenamiento para guardarlo."
            )
            self.ventana.destroy()
            return

        # Canvas
        self.canvas = tk.Canvas(
            self.ventana,
            width=self.columnas * self.tamano_celda,
            height=self.filas * self.tamano_celda,
            bg="white"
        )
        self.canvas.pack(padx=10, pady=10)

        # Dibujar cuadrícula
        self.rectangulos = []
        for fila in range(self.filas):
            fila_rects = []
            for col in range(self.columnas):
                x1 = col * self.tamano_celda
                y1 = fila * self.tamano_celda
                x2 = x1 + self.tamano_celda
                y2 = y1 + self.tamano_celda

                rect = self.canvas.create_rectangle(
                    x1, y1, x2, y2,
                    fill="white",
                    outline="gray"
                )
                fila_rects.append(rect)
            self.rectangulos.append(fila_rects)

        # Eventos para dibujar
        self.canvas.bind("<Button-1>", self.pintar)
        self.canvas.bind("<B1-Motion>", self.pintar)

        # Frame de botones
        marco_botones = tk.Frame(self.ventana)
        marco_botones.pack(pady=10)

        boton_predecir = tk.Button(
            marco_botones,
            text="Predecir",
            width=15,
            command=self.predecir
        )
        boton_predecir.grid(row=0, column=0, padx=5)

        boton_limpiar = tk.Button(
            marco_botones,
            text="Limpiar",
            width=15,
            command=self.limpiar
        )
        boton_limpiar.grid(row=0, column=1, padx=5)

        # Etiqueta resultado
        self.label_resultado = tk.Label(
            self.ventana,
            text="Dibuja una línea en la cuadrícula",
            font=("Arial", 14)
        )
        self.label_resultado.pack(pady=10)

    def pintar(self, evento):
        col = evento.x // self.tamano_celda
        fila = evento.y // self.tamano_celda

        if 0 <= fila < self.filas and 0 <= col < self.columnas:
            self.matriz[fila][col] = 1
            self.canvas.itemconfig(self.rectangulos[fila][col], fill="black")

    def limpiar(self):
        for fila in range(self.filas):
            for col in range(self.columnas):
                self.matriz[fila][col] = 0
                self.canvas.itemconfig(self.rectangulos[fila][col], fill="white")

        self.label_resultado.config(text="Dibuja una línea en la cuadrícula")

    def convertir_a_tensor(self):
        img = torch.tensor(self.matriz, dtype=torch.float32)
        img = img.unsqueeze(0).unsqueeze(0)   # [1, 1, 8, 8]
        return img

    def predecir(self):
        # Verificar si el usuario dibujó algo
        total = sum(sum(fila) for fila in self.matriz)
        if total == 0:
            messagebox.showwarning("Atención", "Debes dibujar algo antes de predecir.")
            return

        entrada = self.convertir_a_tensor()

        with torch.no_grad():
            logits = self.modelo(entrada)
            probs = torch.softmax(logits, dim=1)
            pred = torch.argmax(probs, dim=1).item()

        texto = (
            f"Predicción: {self.nombres_clases[pred]}\n"
            f"vertical={probs[0][0].item():.2f}   "
            f"horizontal={probs[0][1].item():.2f}   "
            f"oblicua={probs[0][2].item():.2f}"
        )
        self.label_resultado.config(text=texto)


# ---------------------------------------------------------
# Programa principal
# ---------------------------------------------------------
ventana = tk.Tk()
app = AplicacionDibujo(ventana)
ventana.mainloop()

Este ejemplo muestra un flujo mucho más realista: una aplicación entrena y guarda el modelo, y otra aplicación independiente lo recupera desde disco para usarlo sin necesidad de volver a entrenar.

26.26 Errores comunes al guardar y cargar modelos

  • Intentar cargar pesos en una arquitectura incompatible.
  • Olvidar usar model.eval() después de cargar.
  • No distinguir entre guardar solo el modelo y guardar un checkpoint completo.
  • Perder el archivo o no registrar bien su ruta.
  • Usar el conjunto de prueba para decidir qué modelo guardar.

26.27 Buenas prácticas para estudiantes

Si estás empezando, estas prácticas suelen ayudarte mucho:

  • Guardar el mejor modelo según validación.
  • Usar state_dict como enfoque principal.
  • Nombrar claramente los archivos guardados.
  • Conservar la definición de la arquitectura junto con el proyecto.
  • Verificar después de cargar que el rendimiento siga siendo coherente.

26.28 Qué debes recordar de este tema

  • Guardar un modelo permite reutilizarlo sin volver a entrenarlo.
  • En PyTorch, la práctica habitual es guardar el state_dict.
  • Para cargarlo, primero se crea una instancia compatible de la arquitectura.
  • torch.save y torch.load son las funciones clave.
  • Un checkpoint puede guardar información adicional como el estado del optimizador y la época.
  • Después de cargar un modelo para inferencia, conviene usar model.eval().

26.29 Cierre conceptual

Guardar y cargar modelos es una habilidad básica pero esencial en PyTorch. Es el puente entre entrenar una red y poder usarla de verdad en otro momento, en otra máquina o en una aplicación concreta.

Dominar este tema significa dar un paso importante hacia un flujo de trabajo más realista y profesional dentro del Deep Learning.