
【ML】Semantic Segmentation Sample Code

Published on 2024/07/28

1. Semantic Segmentation

Semantic segmentation is a computer vision task whose goal is to categorize each pixel in an image into a class or object, producing a dense pixel-wise segmentation map in which every pixel is assigned to a specific class [1].
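To make "pixel-wise map" concrete, here is a minimal sketch (the values are made up for illustration): a segmentation map is simply a 2D integer array holding one class ID per pixel.

・Segmentation Map Example

import numpy as np

# A hypothetical 3x4 segmentation map: 0 = background, 12 = dog
# (values chosen only for illustration)
mask = np.array([
    [0,  0,  0, 0],
    [0, 12, 12, 0],
    [0, 12, 12, 0],
])
print(mask.shape)       # (3, 4): one class ID per pixel
print(np.unique(mask))  # [ 0 12]: the classes present in the image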

2. Code

We use a pre-trained DeepLabV3 model from torchvision.

・Semantic Segmentation

import torch
import torchvision
from torchvision.models.segmentation import deeplabv3_resnet101, DeepLabV3_ResNet101_Weights
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Load pre-trained DeepLabV3 model (the weights argument replaces the deprecated pretrained=True)
model = deeplabv3_resnet101(weights=DeepLabV3_ResNet101_Weights.DEFAULT)
model.eval()

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

def preprocess_image(image_path, target_size=(512, 512)):
    try:
        # Open image and convert to RGB (handles PNG, JPEG, etc.)
        image = Image.open(image_path).convert('RGB')
        width, height = image.size
        print(f'width: {width}')
        print(f'height: {height}')

        
        # Define preprocessing (mean/std are the ImageNet statistics the backbone was trained with)
        preprocess = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        
        # Preprocess image
        input_tensor = preprocess(image)
        input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension
        
        return image, input_tensor
    except Exception as e:
        print(f"Error processing image: {e}")
        return None, None

def segment_image(image_path):
    # Preprocess the image
    original_image, input_tensor = preprocess_image(image_path)
    
    if input_tensor is None:
        return None, None

    input_tensor = input_tensor.to(device)

    # Perform inference
    with torch.no_grad():
        output = model(input_tensor)['out'][0]
    
    # Post-process the output
    output_predictions = output.argmax(0).byte().cpu().numpy()

    return original_image, output_predictions

def show_result(image, segmentation_mask):
    if image is None or segmentation_mask is None:
        print("No valid image or segmentation mask to display.")
        return

    # Convert PIL Image to numpy array
    image_np = np.array(image)

    # Create a color map (plt.cm.get_cmap is deprecated; plt.get_cmap works across versions)
    color_map = plt.get_cmap('viridis')
    # Normalize class IDs to [0, 1]; guard against an all-background (all-zero) mask
    denom = segmentation_mask.max() if segmentation_mask.max() > 0 else 1
    colored_mask = color_map(segmentation_mask / denom)

    # Create subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

    # Display original image
    ax1.imshow(image_np)
    ax1.set_title('Original Image')
    ax1.axis('off')

    # Display segmentation mask
    ax2.imshow(colored_mask)
    ax2.set_title('Segmentation Mask')
    ax2.axis('off')

    plt.tight_layout()
    plt.show()
    
# Example usage
image_path = '/kaggle/input/a-simple-dog/dog.png'

original_image, segmentation_mask = segment_image(image_path)

if segmentation_mask is not None:
    print(np.max(segmentation_mask))  # highest class ID in the mask
    print(np.min(segmentation_mask))  # lowest class ID in the mask (0 = background)

# Visualize the result
show_result(original_image, segmentation_mask)

・Output

(figure: the original dog image and the colored segmentation mask, side by side)

Pretrained models are very useful. They can be used for object detection or foreign object detection. Please try it.
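As a supplementary note, the mask values above are class indices: the torchvision DeepLabV3 weights are trained on a subset of COCO using the 20 Pascal VOC categories plus background, so they can be mapped to names. A minimal sketch (segmentation_mask is the array returned by segment_image above):

・Class Name Mapping

# Pascal VOC class order used by the torchvision segmentation weights
VOC_CLASSES = [
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
    'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
    'train', 'tvmonitor',
]

# Print the class names present in the predicted mask
for class_id in np.unique(segmentation_mask):
    print(class_id, VOC_CLASSES[class_id])

With the weights API, the same list should also be available as DeepLabV3_ResNet101_Weights.DEFAULT.meta['categories'].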

Option

Adjust Threshold

・threshold in post_process_segmentation():
Lower values will make the segmentation more sensitive but may introduce more noise.
・min_size in post_process_segmentation():
This controls the minimum size of regions to keep. Smaller values will retain more detail but may include more noise.

However, tuning these parameters is quite difficult (as shown below).

・Adjust Threshold Code

import torch
import torchvision
from torchvision.models.segmentation import deeplabv3_resnet101, DeepLabV3_ResNet101_Weights
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage

# Load pre-trained DeepLabV3 model (the weights argument replaces the deprecated pretrained=True)
model = deeplabv3_resnet101(weights=DeepLabV3_ResNet101_Weights.DEFAULT)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

def preprocess_image(image_path, target_size=(512, 512)):
    try:
        image = Image.open(image_path).convert('RGB')
        width, height = image.size
        print(f'width: {width}')
        print(f'height: {height}')
        
        preprocess = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        input_tensor = preprocess(image).unsqueeze(0)
        return image, input_tensor
    except Exception as e:
        print(f"Error processing image: {e}")
        return None, None

def segment_image(image_path):
    original_image, input_tensor = preprocess_image(image_path)
    if input_tensor is None:
        return None, None

    input_tensor = input_tensor.to(device)

    with torch.no_grad():
        output = model(input_tensor)['out'][0]
    
    # Instead of argmax, return per-class probability maps: apply softmax over the
    # class dimension so values lie in [0, 1] and thresholding at 0.5 is meaningful
    output_probabilities = torch.softmax(output, dim=0).cpu().numpy()
    return original_image, output_probabilities

def post_process_segmentation(segmentation, threshold=0.5, min_size=100):
    # Apply threshold to create binary mask
    binary_mask = segmentation > threshold

    # Remove small objects
    binary_mask = ndimage.binary_opening(binary_mask)
    
    # Label connected components
    labeled, num_features = ndimage.label(binary_mask)
    
    # Remove small regions
    for i in range(1, num_features+1):
        if np.sum(labeled == i) < min_size:
            labeled[labeled == i] = 0
    
    return labeled

def show_result(image, segmentation):
    if image is None or segmentation is None:
        print("No valid image or segmentation to display.")
        return

    image_np = np.array(image)

    # Process each channel separately
    processed_segmentation = np.zeros_like(segmentation)
    for i in range(segmentation.shape[0]):
        processed_segmentation[i] = post_process_segmentation(segmentation[i])

    # Combine channels
    combined_segmentation = np.sum(processed_segmentation, axis=0)

    # Create a color map ('jet' for more color variety; plt.cm.get_cmap is deprecated)
    color_map = plt.get_cmap('jet')
    # Normalize to [0, 1]; guard against an empty (all-zero) result
    denom = combined_segmentation.max() if combined_segmentation.max() > 0 else 1
    colored_mask = color_map(combined_segmentation / denom)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))

    ax1.imshow(image_np)
    ax1.set_title('Original Image')
    ax1.axis('off')

    ax2.imshow(colored_mask)
    ax2.set_title('Segmentation Result')
    ax2.axis('off')

    plt.tight_layout()
    plt.show()

# Example usage
image_path = '/kaggle/input/a-simple-dog/dog.png'
original_image, segmentation = segment_image(image_path)
show_result(original_image, segmentation)

・Output

(figure: the original image and the post-processed segmentation result, side by side)
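Because good values for threshold and min_size depend on the image, one practical approach is to sweep a few candidates and compare the results visually. A minimal sketch (the candidate values are arbitrary; segmentation is the probability array returned by segment_image above):

・Parameter Sweep Example

import itertools

thresholds = [0.3, 0.5, 0.7]  # arbitrary candidate values
min_sizes = [50, 100, 200]

fig, axes = plt.subplots(len(thresholds), len(min_sizes), figsize=(12, 12))
for (i, th), (j, ms) in itertools.product(enumerate(thresholds), enumerate(min_sizes)):
    # Post-process every class channel with this parameter pair
    processed = np.stack([
        post_process_segmentation(segmentation[c], threshold=th, min_size=ms)
        for c in range(segmentation.shape[0])
    ])
    combined = np.sum(processed, axis=0)
    axes[i, j].imshow(combined)
    axes[i, j].set_title(f'threshold={th}, min_size={ms}')
    axes[i, j].axis('off')
plt.tight_layout()
plt.show()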

Reference

[1] Semantic Segmentation, Papers with Code
