📷

Image Classification with Core ML and Vision

Published 2022/04/06

What is this?

A record of playing around with image classification using Core ML and Vision.

This is the content of a lightning talk I gave at a study session.
Kururun is just too cute, so I made Kururun the subject, but this is purely personal and has nothing to do with any company or organization I belong to. (Just in case.)

Image classification

I built something that takes the imageBuffer from the camera, classifies each frame, and only lets you press the shutter when Kururun is detected.

I wanted to try Core ML, so this post covers the case of using Apple's Create ML.

For launching the camera, UIImagePickerController doesn't allow fine-grained control, so AVCaptureSession is used instead.

Design

1. Set video as the input on an AVCaptureDeviceInput.
2. For each imageBuffer delivered by the AVCaptureVideoDataOutput, use a VNImageRequestHandler to obtain an array of observations for the images to detect.
3. The results come back asynchronously through the VNCoreMLRequest's completion handler (a compact sketch of this wiring follows below).
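
To make the flow concrete before the walkthrough, here is a compact sketch of steps 1-3, condensed from the full source later in this post (the FrameClassifier name is mine, and error handling is omitted):

import AVFoundation
import CoreML
import Vision

// Minimal sketch: camera frames -> Vision -> Core ML.
final class FrameClassifier: NSObject, AVCaptureVideoDataOutputSampleBufferDelegate {
    let session = AVCaptureSession()
    private var requests = [VNRequest]()

    func start() throws {
        // 1. video input
        guard let device = AVCaptureDevice.default(for: .video) else { return }
        session.addInput(try AVCaptureDeviceInput(device: device))

        // 2. frame output; each imageBuffer arrives in captureOutput(_:didOutput:from:)
        let output = AVCaptureVideoDataOutput()
        output.setSampleBufferDelegate(self, queue: DispatchQueue(label: "camera"))
        session.addOutput(output)

        // 3. the request's completion handler receives the results asynchronously
        guard let url = Bundle.main.url(forResource: "V2.KururunClassifier", withExtension: "mlmodelc") else { return }
        let model = try VNCoreMLModel(for: MLModel(contentsOf: url))
        requests = [VNCoreMLRequest(model: model) { request, _ in
            print(request.results ?? [])
        }]
        session.startRunning()
    }

    public func captureOutput(_ output: AVCaptureOutput, didOutput sampleBuffer: CMSampleBuffer, from connection: AVCaptureConnection) {
        guard let buffer = CMSampleBufferGetImageBuffer(sampleBuffer) else { return }
        try? VNImageRequestHandler(cvPixelBuffer: buffer, options: [:]).perform(requests)
    }
}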

Building a Kururun classifier with Create ML

First, prepare the data for training.
Collect images of Kururun; with Kururun images alone, though, Core ML can't make a call when something else is in the frame, so collect a fair amount of other data too.
This time I collected images in four categories: Kururun, other, child, and adult.
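
Create ML's image classifier reads each class label from its folder name, so the training data can be laid out like this (the folder names here are illustrative, though "kururun" and "sonota" match the identifiers checked in the code later):

TrainingData/
├── kururun/   <- Kururun images
├── sonota/    <- everything else
├── kodomo/    <- children
└── otona/     <- adults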

Training the model

Open Create ML from "Xcode" -> "Open Developer Tool" -> "Create ML".

Then use the Get button to download the .mlmodel file. (When you add it to an Xcode project, Xcode compiles it into the .mlmodelc form that the code below loads from the bundle.)

Working with Vision

  1. Create a request
  2. Pass the image to a request handler
  3. It is run through Core ML
  4. Observations are delivered to the request's completion closure

Each step is shown below.

Loading the ML model file

guard let modelURL = Bundle.main.url(
        forResource: "V2.KururunClassifier",
        withExtension: "mlmodelc"
    )
else {
    return NSError(
        domain: "VisionObjectRecognitionViewController",
        code: -1,
        userInfo: [NSLocalizedDescriptionKey: "Model file is missing"]
    )
}

Creating the VNCoreMLRequest

do {
    let visionModel = try VNCoreMLModel(for: MLModel(contentsOf: modelURL))
    let objectRecognition = VNCoreMLRequest(
        model: visionModel,
        completionHandler: { (request, _) in
            // Hop to the main thread before updating published UI state.
            DispatchQueue.main.async(execute: {
                if let results = request.results {
                    self.mlCompletion(results)
                }
            })
        }
    )
    self.mlRequest = [objectRecognition]
} catch let error as NSError {
    // error handling omitted
}
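
As a side note, VNCoreMLRequest also exposes imageCropAndScaleOption, which controls how Vision fits each frame to the model's expected input size. Whether to change it depends on how the training images were prepared; for example:

// How Vision crops/scales each frame to the model's input size,
// e.g. .centerCrop, .scaleFit, or .scaleFill.
objectRecognition.imageCropAndScaleOption = .centerCrop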

Passing the image to the request handler

public func captureOutput(_ output: AVCaptureOutput, didOutput sampleBuffer: CMSampleBuffer, from connection: AVCaptureConnection) {
    // Force the connection to portrait and skip the frame on which it changes.
    if connection.videoOrientation != .portrait {
        connection.videoOrientation = .portrait
        return
    }
    guard let buffer = CMSampleBufferGetImageBuffer(sampleBuffer) else { return }
    let ciImage = CIImage(cvImageBuffer: buffer)

    let exifOrientation = self.exifOrientationFromDeviceOrientation()
    let imageRequestHandler = VNImageRequestHandler(
        ciImage: ciImage,
        orientation: exifOrientation,
        options: [:]
    )
    do {
        // perform(_:) runs synchronously on the delegate's "camera" queue.
        try imageRequestHandler.perform(self.mlRequest)
    } catch {
        print(error)
    }
}

Post-processing the results: the observation data comes back

func mlCompletion(_ results: [Any]) {
    guard let observation = results.first as? VNClassificationObservation else {
        return
    }
    // results comes back sorted by confidence, so results.first is the most confident hit
    print(observation.identifier, observation.confidence) // the result is an identifier and a confidence

    identifier = observation.identifier
    confidence = floor(observation.confidence * 100)

    if observation.identifier == "kururun" {
        // downstream processing
    } else {
        // downstream processing
    }
}
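
The code above looks only at results.first, but since Vision returns the classification observations sorted by confidence you can just as easily inspect the runners-up; a small sketch that could go inside mlCompletion:

// Log the top three candidates instead of only the best one.
let topCandidates = results
    .compactMap { $0 as? VNClassificationObservation }
    .prefix(3)
for candidate in topCandidates {
    print(candidate.identifier, candidate.confidence)
}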

Source code

It probably won't run as-is, but I'm pasting it here so you can get a feel for it.

import AVFoundation
import Combine
import Foundation
import UIKit
import Vision

public class AVCaptureViewModel: NSObject, AVCapturePhotoCaptureDelegate, ObservableObject,
    AVCaptureVideoDataOutputSampleBufferDelegate, UnidirectionalDataFlowType
{

    // MARK: Input
    enum Input {
        case onCloseButtonTapped(Bool)
        case onSendImage(UIImage?)
    }

    func apply(_ input: Input) {
        switch input {
        case .onCloseButtonTapped(let tapped): onCloseButtonTappedSubject.send(tapped)
        case .onSendImage(let image): onSendImageSubject.send(image)
        }
    }

    public var captureSession: AVCaptureSession
    public var videoInput: AVCaptureDeviceInput!
    public var photoOutput: AVCapturePhotoOutput
    public var dataOutput: AVCaptureVideoDataOutput!

    private var cancellables: [AnyCancellable] = []

    private var mlRequest = [VNRequest]()

    @Published var identifier: String?
    @Published var confidence: Float?
    @Published var image: UIImage?
    @Published var resultsImage: String? = "sonota"
    @Published var resultsText: String?
    @Published var buttonDisabled = true
    @Published var closeButtonTapped = false
    @Published var sendImage: UIImage?

    private let onCloseButtonTappedSubject = PassthroughSubject<Bool, Never>()
    private let onSendImageSubject = PassthroughSubject<UIImage?, Never>()

    public override init() {

        self.captureSession = AVCaptureSession()
        self.photoOutput = AVCapturePhotoOutput()
        self.dataOutput = AVCaptureVideoDataOutput()

        super.init()

        bindOutputs()
    }

    func bindOutputs() {
        onCloseButtonTappedSubject
            .removeDuplicates()
            .assign(to: \.closeButtonTapped, on: self)
            .store(in: &cancellables)

        onSendImageSubject
            .assign(to: \.sendImage, on: self)
            .store(in: &cancellables)
    }

    func setupSession() {
        captureSession.beginConfiguration()
        guard let videoCaptureDevice = AVCaptureDevice.default(for: .video) else { return }

        guard let videoInput = try? AVCaptureDeviceInput(device: videoCaptureDevice) else {
            return
        }
        self.videoInput = videoInput
        guard captureSession.canAddInput(videoInput) else { return }
        captureSession.addInput(videoInput)

        guard captureSession.canAddOutput(photoOutput) else { return }
        captureSession.sessionPreset = .photo
        captureSession.addOutput(photoOutput)

        dataOutput.setSampleBufferDelegate(self, queue: DispatchQueue(label: "camera"))
        guard captureSession.canAddOutput(dataOutput) else { return }
        captureSession.addOutput(dataOutput)

        captureSession.commitConfiguration()
        captureSession.startRunning()
    }

    func updateInputOrientation(orientation: UIDeviceOrientation) {
        for conn in captureSession.connections {
            conn.videoOrientation = ConvertUIDeviceOrientationToAVCaptureVideoOrientation(
                deviceOrientation: orientation
            )
        }
    }

    func takePhoto() {
        let photoSetting = AVCapturePhotoSettings()
        photoSetting.flashMode = .auto
        photoSetting.isHighResolutionPhotoEnabled = false
        photoOutput.capturePhoto(with: photoSetting, delegate: self)
        return
    }

    public func captureOutput(
        _ output: AVCaptureOutput,
        didOutput sampleBuffer: CMSampleBuffer,
        from connection: AVCaptureConnection
    ) {
        if connection.videoOrientation != .portrait {
            connection.videoOrientation = .portrait
            return
        }
        guard let buffer = CMSampleBufferGetImageBuffer(sampleBuffer) else { return }
        let ciImage = CIImage(cvImageBuffer: buffer)

        let exifOrientation = self.exifOrientationFromDeviceOrientation()
        let imageRequestHandler = VNImageRequestHandler(
            ciImage: ciImage,
            orientation: exifOrientation,
            options: [:]
        )
        do {
            try imageRequestHandler.perform(self.mlRequest)
        } catch {
            print(error)
        }
    }

    public func photoOutput(
        _ output: AVCapturePhotoOutput,
        didFinishProcessingPhoto photo: AVCapturePhoto,
        error: Error?
    ) {
        // fileDataRepresentation() can return nil, so avoid force-unwrapping
        guard let imageData = photo.fileDataRepresentation() else { return }
        self.image = UIImage(data: imageData)
    }

    func getImageFromSampleBuffer(sampleBuffer: CMSampleBuffer) -> UIImage? {
        guard let pixelBuffer = CMSampleBufferGetImageBuffer(sampleBuffer) else {
            return nil
        }
        CVPixelBufferLockBaseAddress(pixelBuffer, .readOnly)
        let baseAddress = CVPixelBufferGetBaseAddress(pixelBuffer)
        let width = CVPixelBufferGetWidth(pixelBuffer)
        let height = CVPixelBufferGetHeight(pixelBuffer)
        let bytesPerRow = CVPixelBufferGetBytesPerRow(pixelBuffer)
        let colorSpace = CGColorSpaceCreateDeviceRGB()
        let bitmapInfo = CGBitmapInfo(
            rawValue: CGImageAlphaInfo.premultipliedFirst.rawValue
                | CGBitmapInfo.byteOrder32Little.rawValue
        )
        guard
            let context = CGContext(
                data: baseAddress,
                width: width,
                height: height,
                bitsPerComponent: 8,
                bytesPerRow: bytesPerRow,
                space: colorSpace,
                bitmapInfo: bitmapInfo.rawValue
            )
        else {
            return nil
        }
        guard let cgImage = context.makeImage() else {
            return nil
        }
        let image = UIImage(cgImage: cgImage, scale: 1, orientation: .right)
        CVPixelBufferUnlockBaseAddress(pixelBuffer, .readOnly)
        return image
    }

    func setupVision() -> NSError? {

        guard
            let modelURL = Bundle.main.url(
                forResource: "V2.KururunClassifier",
                withExtension: "mlmodelc"
            )
        else {
            return NSError(
                domain: "VisionObjectRecognitionViewController",
                code: -1,
                userInfo: [NSLocalizedDescriptionKey: "Model file is missing"]
            )
        }
        do {
            let visionModel = try VNCoreMLModel(for: MLModel(contentsOf: modelURL))
            let objectRecognition = VNCoreMLRequest(
                model: visionModel,
                completionHandler: { (request, _) in
                    DispatchQueue.main.async(execute: {
                        if let results = request.results {
                            self.mlCompletion(results)
                        }
                    })
                }
            )
            self.mlRequest = [objectRecognition]
        } catch let error as NSError {
            print("(^θ^): \(error)")
            return error
        }
        return nil
    }

    func mlCompletion(_ results: [Any]) {
        guard let observation = results.first as? VNClassificationObservation else {
            return
        }
        print(observation.identifier, observation.confidence)

        identifier = observation.identifier
        confidence = floor(observation.confidence * 100)

        if observation.identifier == "kururun" {
            resultsImage = "kururun"
            resultsText = "くるるんと一緒だね"  // "You're with Kururun!"
            buttonDisabled = false
        } else {
            resultsImage = "sonota"
            resultsText = "くるるんと写ってね"  // "Get in the shot with Kururun"
            buttonDisabled = true
        }
    }

    func exifOrientationFromDeviceOrientation() -> CGImagePropertyOrientation {
        let curDeviceOrientation = UIDevice.current.orientation
        let exifOrientation: CGImagePropertyOrientation
        switch curDeviceOrientation {
        case UIDeviceOrientation.portraitUpsideDown:
            exifOrientation = .left
        case UIDeviceOrientation.landscapeLeft:
            exifOrientation = .upMirrored
        case UIDeviceOrientation.landscapeRight:
            exifOrientation = .down
        case UIDeviceOrientation.portrait:
            exifOrientation = .up
        default:
            exifOrientation = .up
        }
        return exifOrientation
    }
}

func ConvertUIDeviceOrientationToAVCaptureVideoOrientation(
    deviceOrientation: UIDeviceOrientation
) -> AVCaptureVideoOrientation {
    switch deviceOrientation {
    case .portrait:
        return .portrait
    case .portraitUpsideDown:
        return .portraitUpsideDown
    case .landscapeLeft:
        return .landscapeRight
    case .landscapeRight:
        return .landscapeLeft
    default:
        return .portrait
    }
}
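
The view model above also declares conformance to UnidirectionalDataFlowType, which isn't shown in this post. A minimal stand-in, assuming it only formalizes the apply(_:) input pattern (the real definition in the project may differ), could be:

// Hypothetical minimal protocol so the sample compiles;
// the nested Input enum and apply(_:) above satisfy it.
protocol UnidirectionalDataFlowType {
    associatedtype Input
    func apply(_ input: Input)
}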

This time SwiftUI is presented via a hosting controller, so here is the HostingController:

import Combine
import SwiftUI

class PhotoViewController: UIHostingController<ContentView> {

    private var cancellables: [AnyCancellable] = []

    var captureModel: AVCaptureViewModel!

    init() {
        let captureModel = AVCaptureViewModel()

        super.init(rootView: ContentView(captureModel: captureModel))
        self.captureModel = captureModel
    }

    @objc required dynamic init?(coder aDecoder: NSCoder) {
        fatalError("init(coder:) has not been implemented")
    }
}

The preview display part: a UIViewRepresentable that shows a UIView inside the SwiftUI screen.

import AVFoundation
import Foundation
import SwiftUI

public class UIAVCaptureVideoPreviewView: UIView {
    var captureSession: AVCaptureSession!
    var previewLayer: AVCaptureVideoPreviewLayer!

    public init(frame: CGRect, session: AVCaptureSession) {
        self.captureSession = session
        super.init(frame: frame)
    }

    public required init?(coder: NSCoder) {
        super.init(coder: coder)
    }

    func setupPreview(previewSize: CGRect) {
        self.frame = previewSize
        self.previewLayer = AVCaptureVideoPreviewLayer(session: self.captureSession)
        self.previewLayer.frame = self.bounds
        self.updatePreviewOrientation()
        self.layer.addSublayer(previewLayer)
        self.captureSession.startRunning()
    }

    func updateFrame(frame: CGRect) {
        self.frame = frame
        self.previewLayer.frame = frame
    }

    func updatePreviewOrientation() {
        switch UIDevice.current.orientation {
        case .portrait:
            self.previewLayer.connection?.videoOrientation = .portrait
        case .portraitUpsideDown:
            self.previewLayer.connection?.videoOrientation = .portraitUpsideDown
        case .landscapeLeft:
            self.previewLayer.connection?.videoOrientation = .landscapeRight
        case .landscapeRight:
            self.previewLayer.connection?.videoOrientation = .landscapeLeft
        default:
            self.previewLayer.connection?.videoOrientation = .portrait
        }
        return
    }
}

public struct SwiftUIAVCaptureVideoPreviewView: UIViewRepresentable {
    let previewFrame: CGRect
    let captureModel: AVCaptureViewModel

    public func makeUIView(context: Context) -> UIAVCaptureVideoPreviewView {
        let view = UIAVCaptureVideoPreviewView(
            frame: previewFrame,
            session: self.captureModel.captureSession
        )
        view.setupPreview(previewSize: previewFrame)
        return view
    }

    public func updateUIView(_ uiView: UIAVCaptureVideoPreviewView, context: Context) {
        self.captureModel.updateInputOrientation(orientation: UIDevice.current.orientation)
        uiView.updateFrame(frame: previewFrame)
    }
}

The SwiftUI side, including how it hooks up to the view model:

import CoreImage.CIFilterBuiltins
import SwiftUI
import Vision

struct ContentView: View {
    @ObservedObject var captureModel: AVCaptureViewModel
    @GestureState var scale: CGFloat = 1.0

    @State private var isShowingView: Bool = false

    var body: some View {
        captureModel.setupSession()
        captureModel.setupVision()
        return VStack {
            VStack {
                if let imageName = captureModel.resultsImage {
                    Image(imageName)
                        .resizable()
                        .scaledToFit()
                        .frame(height: 100)
                }
                if let resultText = captureModel.resultsText {
                    Text(resultText)
                        .foregroundColor(Color("TextPrimary"))
                        .font(.system(size: 14))
                        .padding(.top, 12)
                }
                if let confidence = captureModel.confidence,
                    let identifier = captureModel.identifier
                {
                    Text("\(identifier)の確率が\(confidence)%")
                        .foregroundColor(Color("Primary"))
                        .font(.system(size: 12))
                        .padding(.top, 12)
                }
            }

            GeometryReader { geometry in
                SwiftUIAVCaptureVideoPreviewView(
                    previewFrame: CGRect(
                        x: 0,
                        y: 0,
                        width: geometry.size.width,
                        height: geometry.size.height
                    ),
                    captureModel: captureModel
                )
            }
            HStack {

                if captureModel.image != nil {
                    VStack {
                        Image(uiImage: captureModel.image!)
                            .resizable()
                            .scaledToFit()
                            .frame(height: 100)
                            .onTapGesture {
                                isShowingView.toggle()
                            }
                            .sheet(isPresented: $isShowingView) {
                                if #available(iOS 14.0, *) {
                                    ImageFilterView(image: captureModel.image!)
                                } else {
                                    // do nothing
                                }
                            }
                        Button(action: {
                            captureModel.apply(.onSendImage(captureModel.image!))
                        }) {
                            Text("この画像を使う")
                                .frame(maxWidth: 68)
                        }
                        .buttonStyle(
                            TertiaryButtonStyle()
                        )
                    }
                }

                Button(action: {
                    captureModel.takePhoto()
                }) {
                    Text("撮影する")
                        .frame(maxWidth: 108)
                }
                .disabled(captureModel.buttonDisabled)
                .buttonStyle(
                    PrimaryButtonStyle()
                )
            }
        }
        .conditionalNavigationBarItems(
            .close,
            action: { captureModel.apply(.onCloseButtonTapped(true)) },
            title: "くるるんといっしょキャンペーン"
        )
    }
}
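
PrimaryButtonStyle, TertiaryButtonStyle, ImageFilterView, and conditionalNavigationBarItems are custom pieces from the project that aren't shown here. Hypothetical minimal button styles, just so the snippet has something to compile against (the real styles will differ):

import SwiftUI

// Hypothetical stand-ins for the project's custom button styles.
struct PrimaryButtonStyle: ButtonStyle {
    func makeBody(configuration: Configuration) -> some View {
        configuration.label
            .padding()
            .background(configuration.isPressed ? Color.gray : Color.blue)
            .foregroundColor(.white)
            .clipShape(Capsule())
    }
}

struct TertiaryButtonStyle: ButtonStyle {
    func makeBody(configuration: Configuration) -> some View {
        configuration.label
            .padding(8)
            .foregroundColor(.blue)
    }
}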
