Im2Col を用いた畳み込み計算を MATLAB で解説
  
新井仁之(早稲田大学)
    
公開日 2024年11月18日 
Ver 1.4.1   2024年1月2日
  
  
CNN では多くの画像のたたみ込み積を計算する必要がある.深層学習では im2col を使って行列の掛け算にする計算が良く知られている.本ノートではこれを前回のチュートリアルで述べた MATLABで作成した Im2Col ([5])を用いて解説する.ここでは便宜上この方式によるたたみ込みを Cnv2 で表す.なお本ノートの MATLAB コードは,[3] のpython コード([2], [4] も参考)に基づいたものである.最後にMATLABの conv2 との比較をする.
1.入力画像例を設定
ここでは次の簡単な入力画像例で説明していく.
一般に画像ファイルは 
「チャネル数」x「画像高さ」x「画像幅」
の 3 次元配列でできている.ここでは Im2Col への入力が必要なため
入力型 = 「画像数」x「チャネル数」x「画像高さ」x「画像幅」
とするが,画像数 = 1 で議論を進める.
また,議論を単純にすることと,MATLAB の conv2 との比較をするため,チャネル数 1 と設定して話を進める.
NumImages = 1; % これは常に 1 に設定.Col2Im に入力するためのダミー次元.
ChannelImages = 1; % 入力画像数を指定.
Image_width =Image_height;
Image = zeros(NumImages,ChannelImages,Image_height,Image_width);
A1 = 1:Image_height * Image_width; 
A1 = reshape(A1,[Image_height Image_width])';
% ChannelImages が 2 以上の場合は,
% Image(1,i,:,:) (i=2,...,ChanelImages) も設定する.
        sq = squeeze(Image(i,j,:,:));
        disp(['Image(',num2str(i),',',num2str(j),',:,:) =']);
end 
     1     2     3     4     5     6
     7     8     9    10    11    12
    13    14    15    16    17    18
    19    20    21    22    23    24
    25    26    27    28    29    30
    31    32    33    34    35    36
% Image は後で変形するので,オリジナルを保存.後で使う.
2.フィルタとバイアスの設定
フィルタの型は
「NumFilters」x「ChannelFilters 」x「 Filter_height」x「Filter_width」 
の 4 次元配列で設定しておく.ニューラルネットでは通常 ChannelFilters = ChannelImages である.
次の例を考える.ここでは NumFilters = 1 で設定する.
ChannelFilters = ChannelImages; 
F_1 = [0 1 1; 1 0 2; 0 1 0]; % テストフィルタ
% F_1 = randi(3,3); %ランダムに設定
% F_1 = [1 0 0; 0 0 0 ;0 0 0]; % 単位インパルス
Filter = zeros(NumFilters,ChannelFilters,Filter_height,Filter_width);
Bias = 0; % Bias は 0 としておく.
        sq = squeeze(Filter(i,j,:,:));
        disp(['Filter(',num2str(i),',',num2str(j),',:,:) ='])
%Filter は後で変形するので,オリジナルを保存.後で使う.
3.Cnv2 の解説 (Im2Col を使った畳み込みの計算)
Stride と padding の設定をする.Padding の設定によって,MATLAB の次の型に対応する.
Padding = 0  ->  'valid' 型
Padding = 1 ->  'same' 型
Padding による画像拡張
Image = padarray(Image,[0 0 Padding Padding],0,'both');
Image データのパラメータ読み出し
[NumImages, ChannelImages,Image_height, Image_width] = size(Image)
NumImages = 1
ChannelImages = 1
Image_height = 6
Image_width = 6
Filter データのパラメータ読み出し
[NumFilters,ChannelFilters,Filter_height,Filter_width] = size(Filter)
NumFilters = 1
ChannelFilters = 1
Filter_height = 3
Filter_width = 3
たたみ込み演算による出力サイズ
Output_height = fix((Image_height - Filter_height + 2*Padding)/Stride)+1
Output_width = fix((Image_width - Filter_width + 2*Padding)/Stride)+1
たたみ込みを画像とフィルタの行列積で計算するため,画像データを Im2Col で並べ替える.Im2Col については前回の説明を参照.
ただし Im2Col は Python 仕様の配列方式に設定しておく(MATLAB仕様の配列方式,Pyhton仕様の配列方式については [5] を参照).
Col1 = Im2Col(Image,[Filter_height,Filter_width],Stride,Padding);
disp(['size(Col1) = [',num2str(sz),']'])
end
     1     7    13     2     8    14     3     9    15
     7    13    19     8    14    20     9    15    21
    13    19    25    14    20    26    15    21    27
    19    25    31    20    26    32    21    27    33
     2     8    14     3     9    15     4    10    16
     8    14    20     9    15    21    10    16    22
    14    20    26    15    21    27    16    22    28
    20    26    32    21    27    33    22    28    34
     3     9    15     4    10    16     5    11    17
     9    15    21    10    16    22    11    17    23
    15    21    27    16    22    28    17    23    29
    21    27    33    22    28    34    23    29    35
     4    10    16     5    11    17     6    12    18
    10    16    22    11    17    23    12    18    24
    16    22    28    17    23    29    18    24    30
    22    28    34    23    29    35    24    30    36
行列積にするためにフィルタ・データを並び変える
Filter = permute(Filter,[3 4 2 1]);
%Filter = permute(Filter,[4 3 2 1]);
Filter = reshape(Filter,[ChannelFilters*Filter_height*Filter_width NumFilters])';
disp(['size(Filter) = [', num2str(sz),']'])
行列積をするため,これを転置する.
Filter = permute(Filter,[2 1]);
画像データとこのフィルタおよびバイアスによるアフィン積をとる.これは affine_product として関数ファイルを作っておく(下記参照).今,Bias = 0 としているので,通常の畳み込み演算
Conv_out = affine_product(Col1,Filter,Bias);
disp(['size(Filter) = [', num2str(sz),']'])
Conv_out = permute(Conv_out,[2 1]);
disp(Conv_out)
    44    80   116   152    50    86   122   158    56    92   128   164    62    98   134   170
Conv_out = reshape(Conv_out,[NumFilters Output_width Output_height NumImages]);
%% Im2Col をMATLABの配列方法に設定した場合.
 Conv_out = permute(Conv_out,[4 1 2 3]);
 %% Im2Col を Python 仕様の配列に設定した場合
% Conv_out = permute(Conv_out,[4 1 3 2]);
    sq = squeeze(Conv_out(i,j,:,:));
    disp(['Conv_out(' num2str(i) ',' num2str(j) ',:,:) の表示'])
end
    44    50    56    62
    80    86    92    98
   116   122   128   134
   152   158   164   170
これで Cnv2 が完成.
4.以上のコードをまとめて関数ファイル Cnv2 を作る
function Z = Cnv2(Image,Filter,Bias,Stride,Padding)
% Stride = 1, Padding = 0 -> 'valid'型畳み込み
% Stride = 1, Padding = 1 -> 'same'型畳み込み
%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Image = Image(Number of Images, Channel, Image_height, Image_width)
% Filter = Filter(Number of Filters, Channel, Filter_height, Filter_width)
% Number of Images x Number of Filters x output_hight x output_width
% Output_hight = fix((Image_height-Block_height+2*Padding)/Stride)+1
% Output_width = fix((Image_width-Block_width+2*Padding)/Stride)+1
% 本ノートの参考文献 [2], [3], [4] の Pythonプログラムにもとづく.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%
[NumImage, ~, Image_height, Image_width] = size(Image);
[NumFilter,Channel,Filter_height,Filter_width] = size(Filter);
Output_height = fix((Image_height -Filter_height +2*Padding)/Stride)+1;
Output_width = fix((Image_width -Filter_width +2*Padding)/Stride)+1;
Image = Im2Col(Image,[Filter_height,Filter_width],Stride,Padding);
% Im2Col が Python 仕様の配列の場合
%Filter = permute(Filter,[4 3 2 1]);
%% Im2Col が MATLAB 仕様の配列方式の場合
Filter = permute(Filter,[3 4 2 1]);
Filter = reshape(Filter,[Channel*Filter_height*Filter_width NumFilter])';
Filter = permute(Filter,[2 1]);
Z = affine_product(Image,Filter,Bias);
Z = reshape(Z,[NumFilter Output_width Output_height NumImage]);
% Z = permute(Z,[4 1 3 2]);
Z = permute(Z,[4 1 2 3]);
    if mod(Filter_height,2) == 0
        Z = Z(:,:,2:Output_height,:);
    if mod(Filter_width,2) == 0
        Z = Z(:,:,:,2:Output_width);
function C = affine_product(A,B,b)
% 画像 A とファイルタ B,そしてバイアス = 実数とのアフィン積
5.Cnv2 を使って計算してみる
早速 Cnv2 を使った計算で検算をする.
Conv_out = Cnv2(Image_org,Filter_org,Bias,Stride,Padding);
    sq = squeeze(Conv_out(i,j,:,:));
    disp(['Conv_out(' num2str(i) ',' num2str(j) ',:,:) の表示'])
end
    44    50    56    62
    80    86    92    98
   116   122   128   134
   152   158   164   170
6.MATLAB の conv2 との比較
MATLAB に conv2 がある.それとここでの相関積(深層学習分野ではこれを畳み込みと読んでいる)との比較をする.
conv2 への入力画像データは,最初に設定した画像データと同じである.
A(:,:) = Image_org(1,1,:,:);
disp(A)
     1     2     3     4     5     6
     7     8     9    10    11    12
    13    14    15    16    17    18
    19    20    21    22    23    24
    25    26    27    28    29    30
    31    32    33    34    35    36
フィルタも既に定めたフィルタと同じものとする.
F = zeros(Filter_height,Filter_width);
F(:,:) = Filter_org(1,1,:,:);
まず,MATLAB の畳み込みの valid を計算する.MATLAB の畳み込みの定義は,ここでの畳み込み(じつは相関積)とは異なる定義なので,結果も異なる.
Z_M = conv2(A,F,"valid");
Z_M(:,:)
    52    58    64    70
    88    94   100   106
   124   130   136   142
   160   166   172   178
しかし,MATLAB の conv2 でフィルタを 180度回転させると,相関積となる.すなわち
Z_M2 = conv2(A,F_rot,"valid");
Z_M2(:,:)
    44    50    56    62
    80    86    92    98
   116   122   128   134
   152   158   164   170
Z = Cnv2(Image_org, Filter_org, Bias, Stride, Padding);
squeeze(Z(1,1,:,:))
    44    50    56    62
    80    86    92    98
   116   122   128   134
   152   158   164   170
もう少し大きな1枚の画像データに対する conv2 との計算時間を比較する.
TestImage = zeros(1,1,M,M);
MATLAB の conv2
Z_M3 = conv2(A,F_rot,"valid");
Cnv2   - Ver 1.3 の訂正 - 
Z = Cnv2(TestImage, Filter_org, Bias, Stride, Padding);
Appendix 1    Im2Col
function Col = Im2Col(Image,block,Stride,Padding)
%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Image = Image(Number of Filter, Channel, Image_height, Image_width)
% block = [number1, number2]
% Stride = number, Padding = number
% (NumImage*Output_height * Output_width) x (Channel * Block_height*Block_width)
% Output_hight = fix((Image_height-Block_height+2*Padding)/Stride)+1
% Output_width = fix((Image_width-Block_width+2*Padding)/Stride)+1
% 参考文献 [1], [2],[3], [4] の Pythonプログラムにもとづく.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%
[NumImage, Channel, Image_height, Image_width] = size(Image);
Output_hight = fix((Image_height-Block_height+2*Padding)/Stride)+1;
Output_width = fix((Image_width-Block_width+2*Padding)/Stride)+1;
col = zeros(NumImage,Block_height*Block_width,Channel,Output_hight,Output_width);
Image = padarray(Image,[0 0 Padding Padding],0,'both');
        % MATLAB の im2col の配列にするには次のようにする.
         col(:,(h-1)*Block_width+w,:,:,:)=Image(:,:,w:Stride:w-1+WS ,h:Stride:h-1+HS);
        % この部分は,Python などの他の文献のようにするには次のようにする.
        % col(:,(h-1)*Block_width+w,:,:,:)=Image(:,:, h:Stride:h-1+HS, w:Stride:w-1+WS);
% MATLAB式 im2col にするには次のようにする:
 Col = permute(col,[2 3 4 5 1]);
% この部分は,Python などの他の文献のようにするには次のようにする.
% Col = permute(col,[2 3 5 4 1]);
Col = reshape(Col, [Channel*Block_height*Block_width NumImage*Output_hight*Output_width ]);
Col = permute(Col,[2 1]);
    
参考文献
[1] https://docs.chainer.org/en/v7.8.1.post1/reference/generated/chainer.functions.im2col.html
[2] 斎藤康毅,ゼロから作る Deep Learning,O'REILLY, 2016.
[3] 立石賢吾,やさしく学ぶディープラーニングがわかる数学のきほん,マイナビ,2019.
[4] 我妻幸長,はじめてのディープラーニング - Python で学ぶニューラルネットワークとバックプロパゲーション,SB Creative, 2018.
[5] 新井仁之,深層学習で使われる im2col を MATLAB で解説,http://www.araiweb.matrix.jp/Program/Im2Col_tutorial2.html
  
履歴          
Ver 1.4.1   2024年1月2日
Ver 1.4  2024年12月30日
Ver 1.3  2024年12月30日
Ver.1.2    2024年12月8日
 Ver 1.1 2024年11月18日 
  
  
Copyright © Hitoshi Arai, 2024