【ML】Semantic Segmentation Sample Code
1. Semantic Segmentation
Semantic segmentation is a computer vision task in which every pixel of an image is assigned to a class or object. The result is a dense, pixel-wise segmentation map of the image.
2. Code
Use a pre-trained model.
・Semantic Segmentation
import torch
import torchvision
from torchvision.models.segmentation import deeplabv3_resnet101
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
# Choose the compute device first: CUDA when a GPU is visible, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build DeepLabV3 with pre-trained weights and switch it to evaluation mode,
# since the script only runs inference.
# NOTE(review): newer torchvision releases prefer a `weights=` argument over
# `pretrained=True` -- confirm against the installed version before changing.
model = deeplabv3_resnet101(pretrained=True)
model.eval()
# Place the model's parameters on the selected device.
model = model.to(device)
def preprocess_image(image_path, target_size=(512, 512)):
    """Load an image file and convert it into a normalized model input batch.

    Args:
        image_path: Path of the image file to open.
        target_size: Size handed to ``transforms.Resize`` before the image is
            converted to a tensor.

    Returns:
        A tuple ``(image, input_tensor)``: the RGB PIL image at its original
        size and the preprocessed tensor with a leading batch dimension of 1.
        If anything fails while loading or preprocessing, the error is
        printed and ``(None, None)`` is returned instead.
    """
    try:
        # Open inside a context manager so the underlying file handle is
        # released deterministically. convert() returns a separate RGB image
        # (handles PNG, JPEG, etc.), so it stays usable after the file closes.
        with Image.open(image_path) as source:
            image = source.convert('RGB')
        width, height = image.size
        print(f'width: {width}')
        print(f'height: {height}')
        # Define preprocessing: resize, convert to a tensor, then normalize
        # each channel (these are the usual ImageNet mean/std values).
        preprocess = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Preprocess image
        input_tensor = preprocess(image)
        input_tensor = input_tensor.unsqueeze(0)  # Add batch dimension
        return image, input_tensor
    except Exception as e:
        # Deliberately best-effort: callers check for None instead of
        # handling exceptions.
        print(f"Error processing image: {e}")
        return None, None
def segment_image(image_path):
    """Run the segmentation model on one image and return a class-index mask.

    Args:
        image_path: Path of the image file to segment.

    Returns:
        A tuple ``(original_image, mask)`` where ``mask`` is a NumPy array
        holding, for every pixel of the resized input, the index of the
        highest-scoring class. ``(None, None)`` when preprocessing failed.
    """
    source_image, batch = preprocess_image(image_path)
    if batch is None:
        return None, None
    batch = batch.to(device)
    # Gradients are not needed for inference.
    with torch.no_grad():
        scores = model(batch)['out'][0]
    # Take the best class along the first axis, then move the result to a
    # byte-typed NumPy array on the CPU.
    class_indices = scores.argmax(0)
    mask = class_indices.byte().cpu().numpy()
    return source_image, mask
def show_result(image, segmentation_mask):
    """Display the original image next to a color-coded segmentation mask.

    Args:
        image: PIL image (or anything ``np.array`` accepts) to show on the left.
        segmentation_mask: NumPy array of per-pixel class indices.

    Returns:
        None. Prints a message and returns early when either argument is None.
    """
    if image is None or segmentation_mask is None:
        print("No valid image or segmentation mask to display.")
        return
    # Convert PIL Image to numpy array
    image_np = np.array(image)
    # plt.get_cmap is used instead of plt.cm.get_cmap, which was deprecated
    # in Matplotlib 3.7 and removed in 3.9.
    color_map = plt.get_cmap('viridis')
    # Scale class indices into [0, 1] for the colormap. An all-background
    # mask has max 0, so avoid the 0/0 division and map everything to 0.0.
    peak = segmentation_mask.max()
    if peak > 0:
        normalized_mask = segmentation_mask / peak
    else:
        normalized_mask = np.zeros(segmentation_mask.shape, dtype=float)
    colored_mask = color_map(normalized_mask)
    # Create subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
    # Display original image
    ax1.imshow(image_np)
    ax1.set_title('Original Image')
    ax1.axis('off')
    # Display segmentation mask
    ax2.imshow(colored_mask)
    ax2.set_title('Segmentation Mask')
    ax2.axis('off')
    plt.tight_layout()
    plt.show()
# Example usage
image_path = '/kaggle/input/a-simple-dog/dog.png'
original_image, segmentation_mask = segment_image(image_path)
# segment_image returns (None, None) when the image could not be loaded, so
# only report the range of class indices when a mask was actually produced.
if segmentation_mask is not None:
    print(np.max(segmentation_mask))
    print(np.min(segmentation_mask))
# Visualize the result (show_result reports missing inputs itself)
show_result(original_image, segmentation_mask)
・Output
Pre-trained models are very useful. This approach can be used for object detection or foreign-object detection. Please try it.
option
Adjust Threshold
・threshold in post_process_segmentation():
Lower values will make the segmentation more sensitive but may introduce more noise.
・min_size in post_process_segmentation():
This controls the minimum size of regions to keep. Smaller values will retain more detail but may include more noise.
However, tuning these values is quite difficult (see the result below).
・Adjust Threshold Code
import torch
import torchvision
from torchvision.models.segmentation import deeplabv3_resnet101
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
# Select the GPU when CUDA is available and fall back to the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load pre-trained DeepLabV3, put it in evaluation mode (inference only), and
# move it to the selected device in one chained expression.
model = deeplabv3_resnet101(pretrained=True).eval().to(device)
def preprocess_image(image_path, target_size=(512, 512)):
    """Read an image from disk and prepare it as a single-item input batch.

    Args:
        image_path: Location of the image file.
        target_size: Size given to ``transforms.Resize``.

    Returns:
        ``(image, input_tensor)`` -- the RGB PIL image at its original size
        and the resized, normalized tensor with a batch dimension added in
        front. On any error the message is printed and ``(None, None)`` is
        returned.
    """
    try:
        # The context manager closes the file handle as soon as the pixel
        # data has been copied into the converted RGB image.
        with Image.open(image_path) as source:
            image = source.convert('RGB')
        width, height = image.size
        print(f'width: {width}')
        print(f'height: {height}')
        # Resize -> tensor -> per-channel normalization (the usual ImageNet
        # mean/std values).
        preprocess = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        input_tensor = preprocess(image).unsqueeze(0)
        return image, input_tensor
    except Exception as e:
        # Best-effort by design: callers test for None rather than catching.
        print(f"Error processing image: {e}")
        return None, None
def segment_image(image_path):
    """Run the segmentation model and return its full per-class output.

    Args:
        image_path: Path of the image file to segment.

    Returns:
        ``(original_image, scores)`` where ``scores`` is the model's ``'out'``
        tensor for the single input, converted to a NumPy array (no argmax is
        taken). ``(None, None)`` when preprocessing failed.
    """
    source_image, batch = preprocess_image(image_path)
    if batch is None:
        return None, None
    with torch.no_grad():
        per_class_output = model(batch.to(device))['out'][0]
    # Keep every class channel instead of collapsing with argmax.
    # NOTE(review): no softmax is applied here, so these look like raw scores
    # rather than probabilities -- confirm before reading downstream
    # thresholds as probabilities.
    class_scores = per_class_output.cpu().numpy()
    return source_image, class_scores
def post_process_segmentation(segmentation, threshold=0.5, min_size=100):
    """Threshold a score map and keep only sufficiently large connected regions.

    Args:
        segmentation: NumPy array of scores.
        threshold: Values strictly greater than this become foreground.
        min_size: Connected regions with fewer pixels than this are erased.

    Returns:
        Integer label array from ``ndimage.label`` in which every surviving
        region keeps its original label and everything else is 0. Labels are
        not renumbered after small regions are removed.
    """
    # Apply threshold to create binary mask
    binary_mask = segmentation > threshold
    # Morphological opening removes small specks and thin protrusions
    binary_mask = ndimage.binary_opening(binary_mask)
    # Label connected components
    labeled, _ = ndimage.label(binary_mask)
    # Count the pixels of every label in a single pass instead of rescanning
    # the whole array once per component.
    component_sizes = np.bincount(labeled.ravel())
    too_small = component_sizes < min_size
    # Look up each pixel's label in the "too small" table and clear those
    # pixels; clearing background (label 0) is a no-op.
    labeled[too_small[labeled]] = 0
    return labeled
def show_result(image, segmentation):
    """Post-process every output channel and display it beside the image.

    Args:
        image: PIL image (or anything ``np.array`` accepts) shown on the left.
        segmentation: NumPy array whose first axis indexes the channels to
            post-process.

    Returns:
        None. Prints a message and returns early when either argument is None.
    """
    if image is None or segmentation is None:
        print("No valid image or segmentation to display.")
        return
    image_np = np.array(image)
    # Process each channel separately
    processed_segmentation = np.zeros_like(segmentation)
    for index, channel in enumerate(segmentation):
        processed_segmentation[index] = post_process_segmentation(channel)
    # Combine channels
    combined_segmentation = np.sum(processed_segmentation, axis=0)
    # plt.get_cmap is used instead of plt.cm.get_cmap, which was deprecated
    # in Matplotlib 3.7 and removed in 3.9.
    color_map = plt.get_cmap('jet')  # 'jet' colormap for more color variety
    # Scale into [0, 1] for the colormap. When no region survived
    # post-processing the max is 0, so skip the 0/0 division.
    peak = combined_segmentation.max()
    if peak > 0:
        normalized = combined_segmentation / peak
    else:
        normalized = np.zeros(combined_segmentation.shape, dtype=float)
    colored_mask = color_map(normalized)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))
    ax1.imshow(image_np)
    ax1.set_title('Original Image')
    ax1.axis('off')
    ax2.imshow(colored_mask)
    ax2.set_title('Segmentation Result')
    ax2.axis('off')
    plt.tight_layout()
    plt.show()
# Example usage: segment one sample image, then plot the image next to the
# post-processed result.
image_path = '/kaggle/input/a-simple-dog/dog.png'
original_image, segmentation = segment_image(image_path=image_path)
show_result(image=original_image, segmentation=segmentation)
・Output
Reference
[1] Semantic Segmentation, papers with code
Discussion