
【TF Tutorial】Chapter 4: Data Preprocessing and Augmentation

Published 2024/07/18

4. Data Preprocessing and Augmentation

4.1 Data Loading with TensorFlow Datasets (tfds) and tf.data

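TensorFlow Datasets (tfds) handles downloading and preparing the data and constructs a tf.data.Dataset from it. With as_supervised=True each element is an (image, label) tuple, and with_info=True additionally returns a tfds.core.DatasetInfo object describing the dataset (features, splits, number of examples).

・Loading a Dataset

import tensorflow as tf
import tensorflow_datasets as tfds

# Without a split argument, tfds.load returns a dict of splits: dataset['train'], dataset['test']
dataset, info = tfds.load('mnist', with_info=True, as_supervised=True)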

・Creating a Dataset Pipeline
We can build efficient data pipelines using tf.data.Dataset objects.

train_ds = dataset['train'].shuffle(10000).batch(32).prefetch(tf.data.AUTOTUNE)
test_ds = dataset['test'].batch(32).prefetch(tf.data.AUTOTUNE)

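Here, the training data is shuffled with a 10,000-element buffer and grouped into batches of 32, while the test data is only batched. prefetch(tf.data.AUTOTUNE) prepares upcoming batches in the background while the model trains on the current one, and tf.data.AUTOTUNE lets TensorFlow tune the prefetch buffer size at runtime.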
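To see what shuffle, batch, and prefetch produce without downloading MNIST, here is a minimal sketch on a synthetic dataset. The 100 random 28×28 "images" and the name demo_ds are made up for illustration. With 100 examples and a batch size of 32, the last batch is partial unless drop_remainder=True is passed to batch.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for MNIST: 100 random uint8 images and integer labels
images = np.random.randint(0, 256, size=(100, 28, 28, 1), dtype=np.uint8)
labels = np.random.randint(0, 10, size=(100,), dtype=np.int64)

# from_tensor_slices yields one (image, label) pair per example,
# the same structure tfds.load(..., as_supervised=True) gives
ds = tf.data.Dataset.from_tensor_slices((images, labels))

# Shuffle (buffer covers the whole dataset), batch, then prefetch
demo_ds = ds.shuffle(100).batch(32).prefetch(tf.data.AUTOTUNE)
batch_sizes = [int(x.shape[0]) for x, _ in demo_ds]
print(batch_sizes)  # [32, 32, 32, 4]

# drop_remainder=True discards the final partial batch
dropped = [int(x.shape[0]) for x, _ in ds.batch(32, drop_remainder=True)]
print(dropped)  # [32, 32, 32]
```

Dropping the remainder is mainly useful when a model or hardware setup needs a fixed batch dimension; otherwise the partial batch is harmless.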

4.2 Data Preprocessing: Normalization, One-Hot Encoding

We can apply preprocessing with Dataset.map and our own function, which takes (data, label) as arguments and returns the preprocessed (data, label).

・Normalization

def normalize_image(image, label):
    # Convert uint8 pixels in [0, 255] to float32 values in [0, 1]
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

train_ds = train_ds.map(normalize_image)
test_ds = test_ds.map(normalize_image)
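A quick check of what normalize_image does to pixel values, as a sketch on a hand-made 1×3 "image" (the pixel values are chosen only for illustration). It also shows tf.keras.layers.Rescaling, a Keras layer that applies the same scaling and can live inside a model instead of the input pipeline.

```python
import numpy as np
import tensorflow as tf

def normalize_image(image, label):
    # Same function as above: uint8 in [0, 255] -> float32 in [0, 1]
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

# Illustrative pixel values: black, a dark gray, white
pixels = tf.constant([[0, 51, 255]], dtype=tf.uint8)
normalized, _ = normalize_image(pixels, 0)
print(normalized.numpy())  # approximately [[0.0, 0.2, 1.0]]

# The Rescaling layer casts to float32 and multiplies by the scale
rescale = tf.keras.layers.Rescaling(1.0 / 255)
rescaled = rescale(pixels)
print(np.allclose(np.asarray(rescaled), normalized.numpy()))  # True
```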

・One-Hot Encoding

def one_hot_encode(image, label):
    # Turn the integer class id (0-9 for MNIST) into a length-10 one-hot vector
    label = tf.one_hot(label, depth=10)
    return image, label

train_ds = train_ds.map(one_hot_encode)
test_ds = test_ds.map(one_hot_encode)
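Because the datasets are already batched when one_hot_encode runs, tf.one_hot turns a batch of integer labels of shape (batch,) into shape (batch, 10). One-hot labels also change which loss to use: categorical_crossentropy expects one-hot targets, while sparse_categorical_crossentropy expects integer ids. A minimal sketch with made-up labels and predictions (the 3-class probabilities are purely illustrative) showing that both losses agree when given matching label formats:

```python
import numpy as np
import tensorflow as tf

def one_hot_encode(image, label):
    # Same function as above: integer class id -> length-10 one-hot vector
    label = tf.one_hot(label, depth=10)
    return image, label

# Hypothetical batch of 4 blank images with integer labels
images = tf.zeros([4, 28, 28, 1])
labels = tf.constant([3, 0, 9, 3])
_, encoded = one_hot_encode(images, labels)
print(encoded.shape)       # (4, 10)
print(encoded[0].numpy())  # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]

# Made-up predicted probabilities for 2 examples over 3 classes
probs = tf.constant([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]])
int_labels = tf.constant([1, 0])

# One-hot targets -> CategoricalCrossentropy; integer targets -> Sparse variant
cce = tf.keras.losses.CategoricalCrossentropy()(tf.one_hot(int_labels, depth=3), probs)
scce = tf.keras.losses.SparseCategoricalCrossentropy()(int_labels, probs)
print(float(cce), float(scce))  # both about 0.434
```

So if the one-hot mapping above is applied, the model should be compiled with loss='categorical_crossentropy'; skipping it and keeping integer labels works just as well with 'sparse_categorical_crossentropy'.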

4.3 Image Data Augmentation

Augmentation can be applied the same way as preprocessing: by mapping a function over the dataset.
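As one example of that map pattern, here is a sketch that uses random ops from the tf.image module, which run inside the tf.data graph. This is a swapped-in technique for illustration, not one of the two approaches this chapter covers; the function name augment_with_tf_image and the synthetic data are hypothetical.

```python
import numpy as np
import tensorflow as tf

def augment_with_tf_image(image, label):
    # Randomly mirror images left-right (accepts (H, W, C) or (N, H, W, C))
    image = tf.image.random_flip_left_right(image)
    # Shift brightness by a random delta in [-0.1, 0.1]
    image = tf.image.random_brightness(image, max_delta=0.1)
    # The brightness shift can leave [0, 1], so clip back into range
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label

# Hypothetical normalized data: 8 float32 images in [0, 1], batched by 4
images = np.random.rand(8, 28, 28, 1).astype(np.float32)
labels = np.arange(8)
ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)
aug_ds = ds.map(augment_with_tf_image, num_parallel_calls=tf.data.AUTOTUNE)

for x, y in aug_ds.take(1):
    print(x.shape, y.numpy())  # (4, 28, 28, 1) [0 1 2 3]
```

Horizontal flips are harmless for many photo datasets but can change the meaning of some characters and digits, so it is worth checking whether a given transformation preserves the label.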

・Basic Augmentation Techniques:

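import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Note: ImageDataGenerator is deprecated in recent TensorFlow releases;
# the Keras preprocessing layers shown below are the recommended replacement.
datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# random_transform expects a single NumPy image (H, W, C), not a batched
# symbolic tensor, so run it per image inside tf.numpy_function.
def random_transform_batch(images):
    return np.stack([datagen.random_transform(img) for img in images]).astype(np.float32)

def apply_datagen(x, y):
    x = tf.numpy_function(random_transform_batch, [x], tf.float32)
    x.set_shape([None, 28, 28, 1])  # restore the MNIST shape lost by numpy_function
    return x, y

# Apply the augmentations to a dataset
augmented_train_ds = train_ds.map(apply_datagen)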

Alternatively, we can use Keras preprocessing layers.
・Augmentation with Keras Layers

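import tensorflow as tf

# These layers now live directly in tf.keras.layers; the older
# tensorflow.keras.layers.experimental.preprocessing path is deprecated.
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.2),  # up to ±0.2 × 2π (±72°)
    tf.keras.layers.RandomZoom(0.2),      # zoom in or out by up to 20%
])

def augment(image, label):
    # training=True makes the random layers apply their transformations
    # (they pass inputs through unchanged in inference mode)
    image = data_augmentation(image, training=True)
    return image, label

train_ds = train_ds.map(augment)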
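Keras also supports a second placement: making the augmentation stack the first part of the model itself, so the random layers transform images during training and pass them through unchanged at inference (evaluate and predict). A minimal sketch, assuming MNIST-shaped 28×28×1 inputs; the one-layer classifier is hypothetical and left untrained, since the point is only where the augmentation sits.

```python
import numpy as np
import tensorflow as tf

# Same augmentation stack as above, written with the tf.keras.layers names
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.2),
    tf.keras.layers.RandomZoom(0.2),
])

# Hypothetical classifier with augmentation as its first stage
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    data_augmentation,
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])

x = tf.random.uniform([2, 28, 28, 1])
# In inference mode the random layers act as identity
unchanged = data_augmentation(x, training=False)
print(np.allclose(np.asarray(unchanged), x.numpy()))  # True
print(model(x, training=False).shape)  # (2, 10)
```

Mapping augmentation over the dataset (as above) runs it on the CPU inside the input pipeline, whereas putting it in the model runs it on the same device as the model and keeps it from ever touching test data.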
