【TF Tutorial】Chapter 4: Data Preprocessing and Augmentation
4. Data Preprocessing and Augmentation
4.1 Data Loading with TensorFlow Datasets (tf.data)
TensorFlow Datasets (TFDS) handles downloading and preparing the data, and constructs a tf.data.Dataset from it.
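・Loading a Dataset
import tensorflow as tf
import tensorflow_datasets as tfds
# as_supervised=True yields (image, label) tuples instead of feature dicts;
# info holds metadata such as split sizes and the number of classes
dataset, info = tfds.load('mnist', with_info=True, as_supervised=True)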
・Creating a Dataset Pipeline
We can build efficient data pipelines using tf.data.Dataset objects.
train_ds = dataset['train'].shuffle(10000).batch(32).prefetch(tf.data.AUTOTUNE)
test_ds = dataset['test'].batch(32).prefetch(tf.data.AUTOTUNE)
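Here, the training data is shuffled (using a 10,000-element buffer) and batched into groups of 32, while the test data is only batched. prefetch(tf.data.AUTOTUNE) prepares upcoming batches in the background while the model is training on the current one, and AUTOTUNE lets tf.data choose the prefetch buffer size at runtime.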
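To see what each stage does without downloading MNIST, here is a minimal sketch that runs the same shuffle, batch, and prefetch chain on a small synthetic dataset. The 100 random 28×28×1 "images" and the helper name `make_pipeline` are assumptions for illustration only.

```python
import numpy as np
import tensorflow as tf

def make_pipeline(images, labels, batch_size=32, shuffle=True):
    """Hypothetical helper: builds the shuffle -> batch -> prefetch chain."""
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    if shuffle:
        # A buffer at least as large as the dataset gives a full shuffle
        ds = ds.shuffle(buffer_size=len(images))
    ds = ds.batch(batch_size)
    # Let tf.data decide how many batches to prepare ahead of time
    return ds.prefetch(tf.data.AUTOTUNE)

# 100 synthetic uint8 images and integer labels standing in for MNIST
images = np.random.randint(0, 256, size=(100, 28, 28, 1), dtype=np.uint8)
labels = np.random.randint(0, 10, size=(100,), dtype=np.int64)

batch_sizes = [x.shape[0] for x, _ in make_pipeline(images, labels)]
print(batch_sizes)  # → [32, 32, 32, 4]
```

The last batch is smaller because 100 is not a multiple of 32; passing `drop_remainder=True` to `batch()` drops it if the model needs a fixed batch size.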
4.2 Data Preprocessing: Normalization, One-Hot Encoding
We apply preprocessing with Dataset.map(), passing a function that takes a (data, label) pair as its arguments and returns the preprocessed (data, label) pair.
・Normalization
def normalize_image(image, label):
    # Convert uint8 pixels in [0, 255] to float32 values in [0, 1]
    image = tf.cast(image, tf.float32) / 255.0
    return image, label
train_ds = train_ds.map(normalize_image)
test_ds = test_ds.map(normalize_image)
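A quick check of what `normalize_image` produces, using four synthetic uint8 images (an assumption standing in for MNIST). This sketch also shows a commonly recommended ordering: map the function over individual elements before batching, and pass `num_parallel_calls` so tf.data can run it in parallel.

```python
import numpy as np
import tensorflow as tf

def normalize_image(image, label):
    # Convert uint8 pixels in [0, 255] to float32 values in [0, 1]
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

# Synthetic data: image 0 is all black (0), the rest are all white (255)
images = np.full((4, 28, 28, 1), 255, dtype=np.uint8)
images[0] = 0
labels = np.arange(4, dtype=np.int64)

ds = (tf.data.Dataset.from_tensor_slices((images, labels))
      .map(normalize_image, num_parallel_calls=tf.data.AUTOTUNE)  # per element
      .batch(2)
      .prefetch(tf.data.AUTOTUNE))

x, y = next(iter(ds))
print(float(tf.reduce_min(x)), float(tf.reduce_max(x)))  # → 0.0 1.0
```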
・One-Hot Encoding
def one_hot_encode(image, label):
    # Integer class ids (0-9) -> one-hot vectors of length 10
    label = tf.one_hot(label, depth=10)
    return image, label
train_ds = train_ds.map(one_hot_encode)
test_ds = test_ds.map(one_hot_encode)
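One practical consequence of one-hot labels is the choice of loss: one-hot labels pair with categorical cross-entropy, while integer labels pair with the sparse variant. The sketch below uses random logits as a stand-in for model output (an assumption) and checks that the two losses agree when each gets the label format it expects.

```python
import numpy as np
import tensorflow as tf

labels = tf.constant([3, 0, 9])             # integer class ids
one_hot = tf.one_hot(labels, depth=10)      # shape (3, 10)
logits = tf.random.normal((3, 10))          # stand-in for model outputs

# One-hot labels -> CategoricalCrossentropy
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
# Integer labels -> SparseCategoricalCrossentropy
scce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

loss_one_hot = float(cce(one_hot, logits))
loss_sparse = float(scce(labels, logits))
print(one_hot.shape, loss_one_hot, loss_sparse)
```

So after this step, a model trained on `train_ds` would be compiled with `loss='categorical_crossentropy'` rather than `'sparse_categorical_crossentropy'`.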
4.3 Image Data Augmentation
Augmentation can be applied the same way as preprocessing: by mapping a function over the dataset.
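As a minimal sketch of this pattern, the hypothetical function below uses only `tf.image` ops, so it runs inside the tf.data graph like any other map function. The particular transforms (random horizontal flip, small brightness jitter) and the synthetic float images are assumptions chosen to show the mechanics; for digits, a horizontal flip can change how a character looks, so it may not suit MNIST in practice.

```python
import numpy as np
import tensorflow as tf

def augment_tf_image(image, label):
    # Hypothetical augmentation built only from TensorFlow ops
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.clip_by_value(image, 0.0, 1.0)  # keep pixels in [0, 1]
    return image, label  # labels pass through unchanged

# Synthetic images already normalized to [0, 1]
images = np.random.rand(8, 28, 28, 1).astype(np.float32)
labels = np.arange(8, dtype=np.int64)

ds = (tf.data.Dataset.from_tensor_slices((images, labels))
      .map(augment_tf_image, num_parallel_calls=tf.data.AUTOTUNE)
      .batch(4))

x, y = next(iter(ds))
print(x.shape, y.numpy())
```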
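・Basic Augmentation with ImageDataGenerator
ImageDataGenerator is deprecated in recent TensorFlow releases, and it operates on NumPy arrays one image at a time, so its methods cannot be called directly inside Dataset.map(). Wrapping it in tf.numpy_function and applying it to the unbatched split works (SciPy must be installed):
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)
def datagen_transform(image):
    # random_transform expects a single (H, W, C) NumPy array
    return datagen.random_transform(image).astype(np.float32)
def augment_with_datagen(image, label):
    image = tf.numpy_function(datagen_transform, [image], tf.float32)
    image.set_shape([28, 28, 1])  # numpy_function drops static shape info
    return image, label
# Apply the augmentations to the unbatched training split, then batch
augmented_train_ds = (dataset['train'].map(normalize_image)
                      .map(augment_with_datagen)
                      .shuffle(10000).batch(32).prefetch(tf.data.AUTOTUNE))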
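Alternatively, you can use Keras preprocessing layers, which are built from TensorFlow ops and therefore work directly inside Dataset.map().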
・Augmentation with Keras Layers
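# These layers live directly under tf.keras.layers; the older
# layers.experimental.preprocessing path is deprecated
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.2),  # up to +/-20% of a full turn (72 degrees)
    tf.keras.layers.RandomZoom(0.2),
])
def augment(image, label):
    # training=True makes sure the random transforms are actually applied
    image = data_augmentation(image, training=True)
    return image, label
train_ds = train_ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)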
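A commonly recommended alternative is to put the same layers inside the model itself: they then run on-device as part of the forward pass during training and are skipped at inference. The tiny Flatten/Dense classifier below is an assumption for illustration; the check confirms that with `training=False` the augmentation stack returns its input unchanged.

```python
import numpy as np
import tensorflow as tf

data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.2),
    tf.keras.layers.RandomZoom(0.2),
])

# Hypothetical classifier with augmentation as its first stage
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    data_augmentation,          # random transforms only in training mode
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
])

x = np.random.rand(2, 28, 28, 1).astype(np.float32)

# In inference mode the random layers pass their inputs through
y_inference = data_augmentation(x, training=False)
print(np.allclose(y_inference, x))
print(model(x, training=False).shape)
```

With this design, the dataset pipeline only needs normalization, and the augmentation is saved together with the model.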