
【TF Tutorial】Chapter 1.5: About callbacks / CV

Published on 2024/07/14

In this chapter, I'll briefly explain callbacks and cross-validation (CV) in TensorFlow.
Both are very useful when training models.

1. Callbacks

・Example

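# ... prepare the tf.data datasets (ds_train, ds_valid) and the target statistics (mean_y, stdd_y)
# model, epochs, early_patience, isTrain and is_interactive() are also defined earlier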

# Normalize target values for training and validation datasets
ds_train_target_normalized = ds_train.map(lambda x, y: (x, (y - mean_y) / stdd_y))
ds_valid_target_normalized = ds_valid.map(lambda x, y: (x, (y - mean_y) / stdd_y))

if isTrain:
    from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping 

    # Model checkpoint callback
    model_checkpoint = ModelCheckpoint(
        filepath='model_epoch_{epoch:02d}.keras',  # One file per epoch, e.g. model_epoch_01.keras
        monitor='val_loss',  # Metric that defines the "best" model (only used when save_best_only=True)
        save_best_only=False,  # Save a checkpoint every epoch, not just the best one
        save_weights_only=False,  # Save the entire model (architecture and weights)
        mode='min',  # Lower val_loss is better (only used when save_best_only=True)
        verbose=1  # Print a message each time a checkpoint is saved
    )

    # Early stopping callback
    early_stopping = EarlyStopping(
        monitor='val_loss',  # Monitor validation loss
        patience=early_patience,  # Number of epochs with no improvement before training stops
        restore_best_weights=True  # Restore model weights from the epoch with the best value of the monitored quantity
    )

    # Train the model
    history = model.fit(
        ds_train_target_normalized,  # Training data
        validation_data=ds_valid_target_normalized,  # Validation data
        epochs=epochs,  # Number of epochs to train
        verbose=1 if is_interactive() else 2,  # Progress bar in interactive sessions, one line per epoch otherwise
        callbacks=[early_stopping, model_checkpoint]  # List of callbacks to apply during training
    )
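The example above trains on normalized targets, so the model also predicts normalized values. The article does not show how mean_y and stdd_y are computed; the sketch below assumes they come from the training targets only, and shows the inverse transform you would apply to model.predict() outputs. The helper names are hypothetical, and plain NumPy arrays stand in for the tf.data pipeline to keep it self-contained.

```python
import numpy as np

def fit_target_scaler(y_train):
    """Compute mean and standard deviation from the TRAINING targets only (hypothetical helper)."""
    y_train = np.asarray(y_train, dtype=np.float64)
    mean_y = y_train.mean()
    stdd_y = y_train.std()
    if stdd_y == 0:
        stdd_y = 1.0  # avoid division by zero for a constant target
    return mean_y, stdd_y

def normalize(y, mean_y, stdd_y):
    # Same transform as the lambda in ds_train.map(...): (y - mean_y) / stdd_y
    return (np.asarray(y, dtype=np.float64) - mean_y) / stdd_y

def denormalize(y_norm, mean_y, stdd_y):
    # Inverse of normalize(): apply this to model.predict() outputs
    return np.asarray(y_norm, dtype=np.float64) * stdd_y + mean_y

y_train = [10.0, 12.0, 14.0, 16.0]
mean_y, stdd_y = fit_target_scaler(y_train)   # mean_y = 13.0
y_norm = normalize(y_train, mean_y, stdd_y)
print(mean_y, stdd_y)
print(denormalize(y_norm, mean_y, stdd_y))    # recovers the original targets
```

Using the training statistics for the validation set too (as the two map() calls above do with the same mean_y and stdd_y) keeps both sets on the same scale.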

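This example uses two callbacks, early_stopping and model_checkpoint:
・model_checkpoint: controls when and what to save during training. Here it saves the whole model after every epoch, with the epoch number in the filename.
・early_stopping: stops training when the monitored metric (val_loss) has not improved for `patience` consecutive epochs. Because restore_best_weights=True, it then restores the weights from the best epoch.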
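To make the patience rule concrete, here is a simplified, pure-Python replay of what EarlyStopping(monitor='val_loss', patience=...) decides epoch by epoch. It is an illustration written for this article, not Keras's source code; simulate_early_stopping and the loss values are made up.

```python
def simulate_early_stopping(val_losses, patience, min_delta=0.0):
    """Replay a min-mode early-stopping rule on a list of per-epoch losses.

    Returns (last_epoch_run, best_epoch), both 0-indexed.
    Simplified illustration; not the actual Keras implementation.
    """
    best = float("inf")
    best_epoch = 0
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:  # improvement: remember it and reset the counter
            best = loss
            best_epoch = epoch
            wait = 0
        else:  # no improvement this epoch
            wait += 1
            if wait >= patience:
                return epoch, best_epoch  # training stops after this epoch
    return len(val_losses) - 1, best_epoch  # all epochs were run

losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.64, 0.68, 0.69, 0.70, 0.71]
print(simulate_early_stopping(losses, patience=3))  # (8, 5)
```

With patience=3, training stops at epoch index 8 because val_loss has not beaten 0.64 (epoch index 5) for three epochs in a row; with restore_best_weights=True the model then goes back to the epoch-5 weights.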

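To use them, pass them as a list to the callbacks argument of model.fit():

model.fit(
    ds_train_target_normalized,
    validation_data=ds_valid_target_normalized,
    epochs=epochs,
    callbacks=[early_stopping, model_checkpoint]
)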
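Because save_best_only=False, model_checkpoint in the first example writes a new file every epoch. As far as I know, Keras builds each filename by calling str.format() on filepath with the 1-based epoch number and that epoch's logs (such as val_loss), so metric values can go into the name as well. The sketch below reproduces only that formatting step in plain Python; checkpoint_name and the log values are hypothetical.

```python
def checkpoint_name(template, epoch_index, logs):
    """Format a checkpoint filepath template (hypothetical helper).

    epoch_index is 0-based, as inside a training loop; the name uses epoch_index + 1.
    logs holds that epoch's metrics, e.g. {"val_loss": 0.61}.
    """
    return template.format(epoch=epoch_index + 1, **logs)

# Made-up metrics for two epochs
fake_logs = [{"loss": 0.52, "val_loss": 0.61}, {"loss": 0.40, "val_loss": 0.47}]
for i, logs in enumerate(fake_logs):
    print(checkpoint_name("model_epoch_{epoch:02d}.keras", i, logs))
    print(checkpoint_name("model_{epoch:02d}_{val_loss:.3f}.keras", i, logs))
```

Putting val_loss in the name makes it easy to pick a checkpoint by eye when every epoch is saved.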

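2. Cross-Validation

In K-fold cross-validation, each row is assigned to one of K folds. The model is trained K times, each time validating on one fold and training on the others, and the K validation scores are averaged.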

・Example

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.callbacks import ModelCheckpoint
import pandas as pd

# Assumes a pandas DataFrame df with numeric feature columns, a binary "target"
# column (0/1), and a "folds" column holding fold ids 0 .. n_folds-1
def create_model(input_dim):
    model = Sequential([
        Input(shape=(input_dim,)),
        Dense(128, activation='relu'),
        Dense(64, activation='relu'),
        Dense(1, activation='sigmoid')  # probability of class 1
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

def cross_validate(df, n_folds=5):
    accuracies = []
    best_models = []
    feature_cols = [c for c in df.columns if c not in ('folds', 'target')]

    for fold in range(n_folds):
        # The current fold is the validation set; all other folds are used for training
        train_data = df[df.folds != fold]
        val_data = df[df.folds == fold]

        X_train = train_data[feature_cols].to_numpy(dtype='float32')
        y_train = train_data['target'].to_numpy(dtype='float32')
        X_val = val_data[feature_cols].to_numpy(dtype='float32')
        y_val = val_data['target'].to_numpy(dtype='float32')

        model = create_model(len(feature_cols))  # fresh, untrained model for every fold

        # Save only the epoch with the highest validation accuracy for this fold
        checkpoint_path = f"best_model_fold_{fold}.keras"
        checkpoint = ModelCheckpoint(checkpoint_path, save_best_only=True, monitor='val_accuracy', mode='max', verbose=0)

        model.fit(X_train, y_train, epochs=10, batch_size=32, validation_data=(X_val, y_val), callbacks=[checkpoint], verbose=0)

        # Reload the best checkpoint instead of using the weights from the last epoch
        best_model = tf.keras.models.load_model(checkpoint_path)
        best_models.append(best_model)

        val_loss, val_acc = best_model.evaluate(X_val, y_val, verbose=0)
        accuracies.append(val_acc)

    return accuracies, best_models

# Example usage
# df = your DataFrame with feature columns plus 'folds' and 'target'
accuracies, best_models = cross_validate(df, n_folds=5)
print("Cross-Validation Accuracies: ", accuracies)
print("Mean Accuracy: ", sum(accuracies) / len(accuracies))
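The CV code assumes df already has a "folds" column with values 0 to n_folds-1. One way to create it, assumed here rather than taken from the article, is a simple stratified assignment: shuffle the rows, then number the rows of each class round-robin so every fold gets roughly the same class balance. add_stratified_folds and the demo data are hypothetical; scikit-learn's StratifiedKFold is a common alternative.

```python
import numpy as np
import pandas as pd

def add_stratified_folds(df, n_folds=5, target_col="target", seed=42):
    """Return a copy of df with a 'folds' column in 0 .. n_folds-1 (hypothetical helper)."""
    shuffled = df.sample(frac=1, random_state=seed)  # shuffle rows reproducibly
    # Number the rows of each class 0, 1, 2, ... in shuffled order, then deal them into folds
    folds = shuffled.groupby(target_col).cumcount() % n_folds
    out = df.copy()
    out["folds"] = folds  # aligned back to the original rows by index
    return out

# Demo: 20 rows, 10 per class, so each of 5 folds should get 2 rows of each class
demo = pd.DataFrame({"x1": np.arange(20, dtype=float), "target": [0] * 10 + [1] * 10})
demo = add_stratified_folds(demo, n_folds=5)
print(demo.groupby(["folds", "target"]).size())
```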

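As this example shows, callbacks are also useful inside a CV loop: ModelCheckpoint with save_best_only=True keeps the best epoch of each fold, so every fold is evaluated with its best model rather than with whatever the last epoch produced.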

That's all for this chapter.
Please use it as a reference if you find it helpful.
