Pythonで実装する画像認識アルゴリズム SLIC 入門

こんにちは。データサイエンスチーム tmtkです。
この記事では、SLIC (Simple Linear Iterative Clustering) を紹介します。紹介にあたって、私がPython 3で実装したものを使って解説していきます。
今回処理する画像
(今回処理する画像。choco.jpg として保存)

SLICとは

SLIC (Simple Linear Iterative Clustering) とは、画像認識にかかわるアルゴリズムのひとつです。Radhakrishna Achanta, Appu Shaji, Kevin Smith, Aurelien Lucchi, Pascal Fua, and Sabine Süsstrunkらが2010年に発明したようです(文献[1, 3])。
画像は画素(pixel)が縦横に並んでできています。superpixelは、距離的・色的に近い画素をひとまとまりにとらえたものです。画像認識の前処理としてsuperpixelを計算しておくことで、画像の情報量を上手に減らし、他の画像認識アルゴリズムが適用しやすくなるようです。
superpixel化された画像
(superpixel化された画像)
SLICは画像を入力として、superpixelへの分割を出力とするアルゴリズムです。分割したいsuperpixelの数と画像を指定してSLICで処理すると、入力画像のsuperpixelへの分割を得ることができます。
superpixelの計算アルゴリズムの中で、SLICは領域境界への追従性、時間計算量、空間計算量などの点で他のアルゴリズムに負けず劣らず優れていることが文献[2]で示されています。

SLICのアルゴリズムの概略

SLICのアルゴリズムはk平均法を基にしており、それにいくつかのアイデアを加えて改良されています。SLICのアルゴリズムの特徴を列挙すると、以下のようになります。

  1. 画像を読み込み、グレースケールやRGB(r, g, b)であらわされている値をLab色空間(l, a, b)に変換する
  2. 座標(x, y)にある色(l, a, b)の画素の特徴量として、5次元ユークリッド空間上の点(l, a, b, x, y)を使う
  3. 手順2.で得た画素の特徴量の空間に、k平均法の亜種を適用し、クラスタリングを行う

SLICのアルゴリズムの詳細

以下で、SLICのアルゴリズムの詳細を解説していきます。

実装に、以下のライブラリを使います。Pythonの画像処理ライブラリscikit-imageには既にSLICが実装されていますが、ここではscikit-imageのSLICは使わず、画像の読み込みとLab色空間への変換にのみscikit-imageを使うことにします。

import sys, math
import numpy as np
from skimage import io, color

SLICクラスを定義します。以下でメソッドを定義していきます。

class SLIC:

コンストラクタで各種パラメータを定義します。kは計算するsuperpixel(クラスタ)の数、mは画素(l, a, b, x, y)の距離を計算するとき色成分(l, a, b)に比べて近さ成分(x, y)をどれだけ重視するかを決めるパラメータです。文献[2]では1\leq m \leq 40で決めるよう書かれているので、デフォルトでm=20とすることにします。

    def __init__(self, k, m = 20):
        """ Constructor.

        k: the number of superpixels.
        m: a parameter to weigh the relative importance of spatial proximity.
        """
        self.k = k
        self.m = m
        self.iter_max = 10 # c.f. the paper.

SLICのアルゴリズムは、大きく初期化と繰り返し処理の二つに分かれています。初期化処理をfit_init()、繰り返し処理をfit_iter()として、それぞれ実装していくことにします。

    def fit(self, img_path):
        """ Calculate superpixels.

        Returns the mask array.
        """
        self.fit_init(img_path)
        self.fit_iter()
        return self.l

初期化処理では、

  1. 画像をLab色空間に変換する
  2. 位置(x, y)にある色(l, a, b)の画素を座標(l, a, b, x, y)の点とみなす
    (コンピュータで処理する都合上、座標は(x, y)ではなく(height, width)という逆転した順番になっています)
  3. k個のクラスタの中心を等間隔に初期化する
    (文献[2]では、物体のふちの部分にクラスタの中心点を置くことをさけるために、等間隔でおいたクラスタの中心の周囲3×3画素も見て、その中で勾配(文献[3]を参照)が一番小さい画素にクラスタの中心をおきなおすと説明されていますが、ここではその処理は省略しています)
  4. i番目の点と最寄のクラスタの中心の距離d[i]\inftyに初期化する
  5. クラスタの直径の近似値S=\sqrt{N/k}Nは画素数)を計算しておく

などの処理を行います。

    def fit_init(self, img_path):
        """
        Read the image from img_path,
        convert to Lab color space,
        and initialize cluster centers.
        """

        img_rgb = io.imread(img_path)
        if img_rgb.ndim != 3 or img_rgb.shape[2] != 3:
            raise Exception("Non RGB file. The shape was {}.".format(img_rgb.shape))
        img_lab = color.rgb2lab(img_rgb)

        self.height = img_lab.shape[0]
        self.width = img_lab.shape[1]

        self.pixels = []
        for h in range(self.height):
            for w in range(self.width):
                self.pixels.append(np.array([img_lab[h][w][0], img_lab[h][w][1], img_lab[h][w][2], h, w]))
        self.size = len(self.pixels)

        # Initialize cluster centers to be regularly spaced.
        self.cluster_center = []
        k_w = int(math.sqrt(self.k * self.width / self.height)) + 1
        k_h = int(math.sqrt(self.k * self.height / self.width)) + 1
        for h_cnt in range(k_h):
            h = (2 * h_cnt + 1) * self.height // (2 * k_h)
            for w_cnt in range(k_w):
                w = (2 * w_cnt + 1) * self.width // (2 * k_w)
                self.cluster_center.append(self.pixels[h*self.width + w])
        self.k = k_w*k_h



        self.l = [None] * self.size # The cluster labels
        self.d = [math.inf] * self.size # The distance between a pixel and the nearest cluster center
        self.S = int(math.sqrt(self.size/self.k)) # The approximate distance between cluster centers
        self.metric = np.diagflat([1/(self.m**2)]*3 +  [1/(self.S**2)]*2)

繰り返し処理では、k平均法の類似を行います。
各クラスタの中心(l_j, a_j, b_j, x_j, y_j)ごとに、i番目の点(l_i, a_i, b_i, x_i, y_i)との距離を計算します。その際、中心からx, yの差がS以下(つまりx_j-S\leq x_i\leq x_j+S, y_j - S \leq y_i \leq y_j + S)の点のみについて距離を計算します。そして、その距離が記録されている最小距離d[i]より小さかった場合、i番目の点はj番目のクラスタに所属する、と更新します。おおむね通常のk平均法と同じですが、クラスタの中心との距離を計算する点が、クラスタの中心の位置と画素の位置が近い点のみに絞られていることが特徴です。
所属するクラスタを更新したら、クラスタの中心も更新します(calc_new_center())。
また、距離の定義も通常のユークリッド距離とは別の距離を用います(distance())。点(l_1, a_1, b_1, x_1, y_1)と点(l_2, a_2, b_2, x_2, y_2)の距離は、\sqrt{\frac{(l_1-l_2)^2+(a_1-a_2)^2+(b_1-b_2)^2}{m^2}+\frac{(x_1-x_2)^2+(y_1-y_2)^2}{S^2}}と定義します。これは、色の尺度(l, a, b)と位置の尺度(x, y)のスケールが異なるためです。
文献[2]によると、このk平均法の類似の繰り返し回数は、10回程度で十分であることが経験的に知られているそうです。

    def fit_iter(self):
        """ Iteration step.
        """
        for iter_cnt in range(self.iter_max):
            for center_idx, center in enumerate(self.cluster_center):
                for h in range(max(0, int(center[3])-self.S), min(self.height, int(center[3])+self.S)):
                    for w in range(max(0, int(center[4])-self.S), min(self.width, int(center[4])+self.S)):
                        d = self.distance(self.pixels[h*self.width + w], center)
                        if d < self.d[h*self.width + w]:
                            self.d[h*self.width + w] = d
                            self.l[h*self.width + w] = center_idx
            self.calc_new_center()

    def distance(self, x, y):
        """ Squared distance between x and y.
        """
        return (x-y).dot(self.metric).dot(x-y)

    def calc_new_center(self):
        """ Caluclate new cluster centers.
        """
        cnt = [0] * self.k
        new_cluster_center = [np.array([0., 0., 0., 0. ,0.]) for _ in range(self.k)]
        for i in range(self.size):
            new_cluster_center[self.l[i]] += self.pixels[i]
            cnt[self.l[i]] += 1
        for i in range(self.k):
            new_cluster_center[i] /= cnt[i]
        self.cluster_center = new_cluster_center

ここまでの実装で、 SLIC(k=100).fit("choco.jpg") などとすると座標(x, y)ごとに何番目のsuperpixelに所属するのかのラベルの配列が計算できるようになりました。

superpixelごとにsuperpixelに属する画素の色の平均を計算し、RGBに変換してから返すメソッドも実装します。

    def transform(self):
        """ Returns new image RGB ndarray """
        cnt = [0] * self.k
        cluster_color = [np.array([0., 0., 0.]) for _ in range(self.k)]
        for i in range(self.size):
            cluster_color[self.l[i]] += self.pixels[i][:3]
            cnt[self.l[i]] += 1
        for i in range(self.k):
            cluster_color[i] /= cnt[i]
        new_img_lab = np.zeros((self.height, self.width, 3))
        for h in range(self.height):
            for w in range(self.width):
                new_img_lab[h][w] = cluster_color[self.l[h*self.width + w]]
        return color.lab2rgb(new_img_lab)

ここまでの実装をchoco.jpgにかけ、変換してみます。私の環境で60秒くらいかかります。

    slic = SLIC(k = 100)
    slic.fit("choco.jpg")
    res = slic.transform()
    io.imshow(res)

choco.jpgにここまでのSLICをかけたもの
写っている物体の境界にうまく追従していることがわかると思います。

superpixelを連結にするため、文献[1]の実装では、孤立したsuperpixelを近くの大きなsuperpixelに併合する後処理が追加されていますが、この記事では省略します。scikit-imageで実装されているSLICによる結果は以下のとおりです。こちらはCythonで実装されているため、1秒程度で終わります。

from skimage import io, segmentation, color
img = io.imread("choco.jpg")
label = segmentation.slic(img, compactness=20)
out = color.label2rgb(label, img, kind = 'avg')
io.imsave("lena_skimage.png", out)

choco.jpgにscikit-imageのSLICをかけたもの
scikit-imageでは孤立したsuperpixelを併合する後処理が追加されているため、イチゴの部分などに孤立したsuperpixelがありません。

さらなる話題

パラメータmを自動的に決めるSLICOという手法もあります(文献[1, 2])。

まとめ

画像認識の前処理に使われるsuperpixelを計算するアルゴリズムのひとつであるSLICの紹介・解説をしました。SLICはk平均法を応用したアルゴリズムです。

SLIC クラスの全コード

"""
SLIC implementation in Python 3
"""

import sys, math
import numpy as np
from skimage import io, color


class SLIC:
    def __init__(self, k, m = 20):
        """ Constructor.
        
        k: the number of superpixels.
        m: a parameter to weigh the relative importance of spatial proximity.
        """
        self.k = k
        self.m = m
        self.iter_max = 10 # c.f. the paper.

    def fit(self, img_path):
        """ Calculate superpixels.
        
        Returns the mask array.
        """
        self.fit_init(img_path)
        self.fit_iter()
        return self.l

    def fit_init(self, img_path):
        """
        Read the image from img_path,
        convert to Lab color space,
        and initialize cluster centers.
        """
        
        img_rgb = io.imread(img_path)
        if img_rgb.ndim != 3 or img_rgb.shape[2] != 3:
            raise Exception("Non RGB file. The shape was {}.".format(img_rgb.shape))
        img_lab = color.rgb2lab(img_rgb)
        
        self.height = img_lab.shape[0]
        self.width = img_lab.shape[1]
        
        self.pixels = []
        for h in range(self.height):
            for w in range(self.width):
                self.pixels.append(np.array([img_lab[h][w][0], img_lab[h][w][1], img_lab[h][w][2], h, w]))
        self.size = len(self.pixels)

        # Initialize cluster centers to be regularly spaced.
        self.cluster_center = []
        k_w = int(math.sqrt(self.k * self.width / self.height)) + 1
        k_h = int(math.sqrt(self.k * self.height / self.width)) + 1
        for h_cnt in range(k_h):
            h = (2 * h_cnt + 1) * self.height // (2 * k_h)
            for w_cnt in range(k_w):
                w = (2 * w_cnt + 1) * self.width // (2 * k_w)
                self.cluster_center.append(self.pixels[h*self.width + w])
        self.k = k_w*k_h



        self.l = [None] * self.size # The cluster labels
        self.d = [math.inf] * self.size # The distance between a pixel and the nearest cluster center
        self.S = int(math.sqrt(self.size/self.k)) # The approximate distance between cluster centers
        self.metric = np.diagflat([1/(self.m**2)]*3 +  [1/(self.S**2)]*2)

    def fit_iter(self):
        """ Iteration step.
        """
        for iter_cnt in range(self.iter_max):
            for center_idx, center in enumerate(self.cluster_center):
                for h in range(max(0, int(center[3])-self.S), min(self.height, int(center[3])+self.S)):
                    for w in range(max(0, int(center[4])-self.S), min(self.width, int(center[4])+self.S)):
                        d = self.distance(self.pixels[h*self.width + w], center)
                        if d < self.d[h*self.width + w]:
                            self.d[h*self.width + w] = d
                            self.l[h*self.width + w] = center_idx
            self.calc_new_center()

    def distance(self, x, y):
        return (x-y).dot(self.metric).dot(x-y)
        self.iter_max = 10 # c.f. the paper.

    def fit(self, img_path):
        """ Calculate superpixels.
        
        Returns the mask array.
        """
        self.fit_init(img_path)
        self.fit_iter()
        return self.l

    def fit_init(self, img_path):
        """
        Read the image from img_path,
        convert to Lab color space,
        and initialize cluster centers.
        """
        
        img_rgb = io.imread(img_path)
        if img_rgb.ndim != 3 or img_rgb.shape[2] != 3:
            raise Exception("Non RGB file. The shape was {}.".format(img_rgb.shape))
        img_lab = color.rgb2lab(img_rgb)
        
        self.height = img_lab.shape[0]
        self.width = img_lab.shape[1]
        
        self.pixels = []
        for h in range(self.height):
            for w in range(self.width):
                self.pixels.append(np.array([img_lab[h][w][0], img_lab[h][w][1], img_lab[h][w][2], h, w]))
        self.size = len(self.pixels)

        # Initialize cluster centers to be regularly spaced.
        self.cluster_center = []
        k_w = int(math.sqrt(self.k * self.width / self.height)) + 1
        k_h = int(math.sqrt(self.k * self.height / self.width)) + 1
        for h_cnt in range(k_h):
            h = (2 * h_cnt + 1) * self.height // (2 * k_h)
            for w_cnt in range(k_w):
                w = (2 * w_cnt + 1) * self.width // (2 * k_w)
                self.cluster_center.append(self.pixels[h*self.width + w])
        self.k = k_w*k_h



        self.l = [None] * self.size # The cluster labels
        self.d = [math.inf] * self.size # The distance between a pixel and the nearest cluster center
        self.S = int(math.sqrt(self.size/self.k)) # The approximate distance between cluster centers
        self.metric = np.diagflat([1/(self.m**2)]*3 +  [1/(self.S**2)]*2)

    def fit_iter(self):
        """ Iteration step.
        """
        for iter_cnt in range(self.iter_max):
            for center_idx, center in enumerate(self.cluster_center):
                for h in range(max(0, int(center[3])-self.S), min(self.height, int(center[3])+self.S)):
                    for w in range(max(0, int(center[4])-self.S), min(self.width, int(center[4])+self.S)):
                        d = self.distance(self.pixels[h*self.width + w], center)
                        if d < self.d[h*self.width + w]:
                            self.d[h*self.width + w] = d
                            self.l[h*self.width + w] = center_idx
            self.calc_new_center()

    def distance(self, x, y):
        """ Squared distance between x and y.
        """
        return (x-y).dot(self.metric).dot(x-y)

    def calc_new_center(self):
        """ Caluclate new cluster centers.
        """
        cnt = [0] * self.k
        new_cluster_center = [np.array([0., 0., 0., 0. ,0.]) for _ in range(self.k)]
        for i in range(self.size):
            new_cluster_center[self.l[i]] += self.pixels[i]
            cnt[self.l[i]] += 1
        for i in range(self.k):
            new_cluster_center[i] /= cnt[i]
        self.cluster_center = new_cluster_center

    def transform(self):
        """ Returns new image RGB ndarray """
        cnt = [0] * self.k
        cluster_color = [np.array([0., 0., 0.]) for _ in range(self.k)]
        for i in range(self.size):
            cluster_color[self.l[i]] += self.pixels[i][:3]
            cnt[self.l[i]] += 1
        for i in range(self.k):
            cluster_color[i] /= cnt[i]
        new_img_lab = np.zeros((self.height, self.width, 3))
        for h in range(self.height):
            for w in range(self.width):
                new_img_lab[h][w] = cluster_color[self.l[h*self.width + w]]
        return color.lab2rgb(new_img_lab)

文献

  1. Superpixel segmentation | IVRL
  2. Radhakrishna Achanta, Appu Shaji, Kevin Smith, Aurelien Lucchi, Pascal Fua, and Sabine Süsstrunk, SLIC Superpixels Compared to State-of-the-art Superpixel Methods, IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. 34, num. 11, p. 2274 – 2282, May 2012.
  3. Radhakrishna Achanta, Appu Shaji, Kevin Smith, Aurelien Lucchi, Pascal Fua, and Sabine Süsstrunk, SLIC Superpixels, EPFL Technical Report no. 149300, June 2010.
  4. k平均法 – Wikipedia
  5. Lab色空間 – Wikipedia
  6. バレンタインのチョコレートケーキを焼く女性|ぱくたそフリー素材
  7. scikit-image: Image processing in Python — scikit-image
  8. Normalized Cut — skimage v0.14dev docs
AWS移行支援キャンペーン

あなたにおすすめの記事