Im2Col を用いた最大プーリングを MATLAB で解説
    
新井仁之(早稲田大学)
公開日 2024年11月18日
Ver. 1.2 2024年12月20日
Ver. 1.3  2025年1月1日
  
ここでは解説
「深層学習で使われる im2col を MATLABで解説」(http://www.araiweb.matrix.jp/Program/Im2Col_tutorial2.html)で述べた Im2Col を使った最大プーリングの方法について解説する.なお本ノートの MATLAB コードは,python コード([1], [2], [3],特に [2] )に基づいたものである.
まず最大プーリングとは,たとえば 2 x 2 の小ブロックの中の最大値のみを抜き出して,並べる操作のことである.元の画像が 2M x 2M  のとき,あるいは 2M+1 x 2M+1 のときは,M x M の画像となる.たとえば次のようにである.
もちろん別のブロックの分け方,スライドの仕方でもよい.
Pooling はたたみ込み層で行うが,たたみ込み層では入力画像数が複数ある.単純な具体例で操作がどのようなものかを見るため,次の例で考えていく.
この例では入力データは3個あるものとする.
この例では,入力データは 
[ Number of Filters, Number of Channels, Image_height, Image_width]
の4階テンソルであるとする.深層学習では,Channels は前層から受け渡されるデータの数を表し,Number of Filters は考えている層内のニューロンの個数を表す.
特に [3,2,6,6] の場合で見ていく.
まずこの層への仮想的な入力データを作る.
Image = zeros(NumFilters,NumChannels,6,6);
Image_org = Image; %保存のため
        disp(['Image(',num2str(i),',',num2str(j),',:,:) = '])
end
     1     2     3     4     5     6
     7     8     9    10    11    12
    13    14    15    16    17    18
    19    20    21    22    23    24
    25    26    27    28    29    30
    31    32    33    34    35    36
     1     2     3     4     5     6
     7     8     9    10    11    12
    13    14    15    16    17    18
    19    20    21    22    23    24
    25    26    27    28    29    30
    31    32    33    34    35    36
    35     1     6    26    19    24
     3    32     7    21    23    25
    31     9     2    22    27    20
     8    28    33    17    10    15
    30     5    34    12    14    16
     4    36    29    13    18    11
    35     1     6    26    19    24
     3    32     7    21    23    25
    31     9     2    22    27    20
     8    28    33    17    10    15
    30     5    34    12    14    16
     4    36    29    13    18    11
     1     1     1     1     1     1
     1     1     1     1     1     1
     1     1     1     1     1     1
     1     1     1     1     1     1
     1     1     1     1     1     1
     1     1     1     1     1     1
     1     1     1     1     1     1
     1     1     1     1     1     1
     1     1     1     1     1     1
     1     1     1     1     1     1
     1     1     1     1     1     1
     1     1     1     1     1     1
%このノートでは Padding = 0, Stride = 2 の場合に対応
Image = padarray(Image,[0 0 Padding Padding],0,'both');
%% 今の場合は Padding = 0 なので変化なし.
%         disp(['Image(',num2str(i),',',num2str(j),',:,:) =']);
%         squeeze(Image(i,j,:,:))
[NumImage, Channel, Image_height, Image_width] = size(Image)
NumImage = 3
Channel = 2
Image_height = 6
Image_width = 6
2 x 2 のブロック内で,最大をとる最大プーリングを考える.ブロックを設定する.
最大プーリングで出力される画像のサイズは,たとえば入力データが 2M x 2N ならば M x N になる.これは (2M-2)/2+1=M,(2N-2)/2+1=N である.奇数 2M+1 の場合も同様の値になる.実際,
(2M+1-2)/2+1= M-1+1=M
Output_height = fix((Image_height-block_height+2*Padding)/Stride)+1;
Output_width = fix((Image_width-block_width+2*Padding)/Stride)+1;
disp(['[Output_height, output_width] = ','[', num2str(Output_height),',',num2str(Output_width),']']);
[Output_height, output_width] = [3,3]
Image の Im2Col は次のようなものになっている.
Col = Im2Col(Image,[block_height,block_width],Stride,Padding);
% [NumImage*Output_height*Output_width  Channel*block_height*block_width]
% = [3*3*3 x 2*2*2] = [27 8]
disp(Col);
     1     2     7     8     1     2     7     8
     3     4     9    10     3     4     9    10
     5     6    11    12     5     6    11    12
    13    14    19    20    13    14    19    20
    15    16    21    22    15    16    21    22
    17    18    23    24    17    18    23    24
    25    26    31    32    25    26    31    32
    27    28    33    34    27    28    33    34
    29    30    35    36    29    30    35    36
    35     1     3    32    35     1     3    32
     6    26     7    21     6    26     7    21
    19    24    23    25    19    24    23    25
    31     9     8    28    31     9     8    28
     2    22    33    17     2    22    33    17
    27    20    10    15    27    20    10    15
    30     5     4    36    30     5     4    36
    34    12    29    13    34    12    29    13
    14    16    18    11    14    16    18    11
     1     1     1     1     1     1     1     1
     1     1     1     1     1     1     1     1
     1     1     1     1     1     1     1     1
     1     1     1     1     1     1     1     1
     1     1     1     1     1     1     1     1
     1     1     1     1     1     1     1     1
     1     1     1     1     1     1     1     1
     1     1     1     1     1     1     1     1
     1     1     1     1     1     1     1     1
Col = permute(Col,[2 1]);
% [Channel*block_height*block_width NumImage*Output_height*Output_width]
disp(Col)
     1     3     5    13    15    17    25    27    29    35     6    19    31     2    27    30    34    14     1     1     1     1     1     1     1     1     1
     2     4     6    14    16    18    26    28    30     1    26    24     9    22    20     5    12    16     1     1     1     1     1     1     1     1     1
     7     9    11    19    21    23    31    33    35     3     7    23     8    33    10     4    29    18     1     1     1     1     1     1     1     1     1
     8    10    12    20    22    24    32    34    36    32    21    25    28    17    15    36    13    11     1     1     1     1     1     1     1     1     1
     1     3     5    13    15    17    25    27    29    35     6    19    31     2    27    30    34    14     1     1     1     1     1     1     1     1     1
     2     4     6    14    16    18    26    28    30     1    26    24     9    22    20     5    12    16     1     1     1     1     1     1     1     1     1
     7     9    11    19    21    23    31    33    35     3     7    23     8    33    10     4    29    18     1     1     1     1     1     1     1     1     1
     8    10    12    20    22    24    32    34    36    32    21    25    28    17    15    36    13    11     1     1     1     1     1     1     1     1     1
これを reshape して
(「block_height」*「block_width」)x  (「Channel*NumImage」*「Output_height」*「Output_width」)
にする:
Col = reshape(Col,[block_height*block_width Channel*NumImage*Output_height*Output_width]);
disp(Col)
     1     1     3     3     5     5    13    13    15    15    17    17    25    25    27    27    29    29    35    35     6     6    19    19    31    31     2     2    27    27    30    30    34    34    14    14     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1
     2     2     4     4     6     6    14    14    16    16    18    18    26    26    28    28    30    30     1     1    26    26    24    24     9     9    22    22    20    20     5     5    12    12    16    16     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1
     7     7     9     9    11    11    19    19    21    21    23    23    31    31    33    33    35    35     3     3     7     7    23    23     8     8    33    33    10    10     4     4    29    29    18    18     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1
     8     8    10    10    12    12    20    20    22    22    24    24    32    32    34    34    36    36    32    32    21    21    25    25    28    28    17    17    15    15    36    36    13    13    11    11     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1
Image の列に関する最大値を Max_values,最大値の位置を Max_positions とする.
[Max_values,Max_positions] = max(Col); 
disp(Max_values);
     8     8    10    10    12    12    20    20    22    22    24    24    32    32    34    34    36    36    35    35    26    26    25    25    31    31    33    33    27    27    36    36    34    34    18    18     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1
disp(Max_positions);
     4     4     4     4     4     4     4     4     4     4     4     4     4     4     4     4     4     4     1     1     2     2     4     4     1     1     3     3     1     1     4     4     1     1     3     3     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1     1
Pooling_Image = reshape(Max_values,[Channel Output_width Output_height NumImage]);
% この reshape については Appendix 2 を参照.
        disp(['Pooling_Image(',num2str(i),',:,:,',num2str(j),') =']);
        disp(squeeze(Pooling_Image(i,:,:,j)));
end
     8    20    32
    10    22    34
    12    24    36
    35    31    36
    26    33    34
    25    27    18
     8    20    32
    10    22    34
    12    24    36
    35    31    36
    26    33    34
    25    27    18
Pooling_Image = permute(Pooling_Image,[4 1 3 2]); 
%[NumImages Channel Output_height Output_width
% この reshape については Appendix 2 を参照.
この reshape の結果は:
        disp(['Pooling_Image(',num2str(i),',',num2str(j),',:,:) =']);
        squeeze(Pooling_Image(i,j,:,:))
end
     8    10    12
    20    22    24
    32    34    36
     8    10    12
    20    22    24
    32    34    36
    35    26    25
    31    33    27
    36    34    18
    35    26    25
    31    33    27
    36    34    18
以上の操作をまとめて関数ファイルとする.
  
最大プーリングの関数 Pooling Max
function [Pooling_Image,Max_positions] = Pooling_Max(Image,block,Stride,Padding)
% 入力タイプ Images(Number of Images, Number of Channels, Image_height, Image_width)
% 出力タイプ Pooling_image(Number of Images Number of Channels Output_height Output_width)
Image = padarray(Image,[0 0 Padding Padding],0,'both');
[NumImage, Channels, Image_height, Image_width] = size(Image);
Output_height = fix((Image_height - block_height + 2*Padding)/Stride)+1;
Output_width = fix((Image_width - block_width + 2*Padding)/Stride)+1;
Col = Im2Col(Image,[block_height,block_width],Stride,Padding);
Col = permute(Col,[2 1]);
Col = reshape(Col,[block_height*block_width Channels*NumImage*Output_height*Output_width]);
[Max_values,Max_positions] = max(Col);
Pooling_Image = reshape(Max_values,[Channels Output_height Output_width NumImage]);
Pooling_Image = permute(Pooling_Image,[4 1 3 2]); 
% Pooling image のサイズ:[NumImage Channel Output_height Output_width]
  
Max_Pooling を使って計算する
[Pooling_Image,Max_Positions] = Pooling_Max(Image_org,[2 2],2,0);
        disp(['Pooling_Image(',num2str(i),',',num2str(j),',:,:) =']);
        squeeze(Pooling_Image(i,j,:,:))
end
     8    10    12
    20    22    24
    32    34    36
     8    10    12
    20    22    24
    32    34    36
    35    26    25
    31    33    27
    36    34    18
    35    26    25
    31    33    27
    36    34    18
 
Appeddix 1
function Col = Im2Col(Image,block,Stride,Padding)
%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Image = Image(Number of Filter, Channel, Image_height, Image_width)
% block = [number1, number2]
% Stride = number, Padding = number
% (NumImage*Output_height * Output_width) x (Channel * Block_height*Block_width)
% Output_hight = fix((Image_height-Block_height+2*Padding)/Stride)+1
% Output_width = fix((Image_width-Block_width+2*Padding)/Stride)+1
% 次の [A], [B], [C], [D] の Pythonプログラムにもとづく.
% [A] https://docs.chainer.org/en/v7.8.1.post1/reference/generated/chainer.functions.im2col.html
% [B] 斎藤康毅,ゼロから作る Deep Learning,O'REILLY, 2016.
% [C] 立石賢吾,やさしく学ぶディープラーニングがわかる数学のきほん,マイナビ,2019.
% [D] 我妻幸長,はじめてのディープラーニング - Python で学ぶニューラルネットワークとバックプロパゲーション,SB Creative, 2018.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%
[NumImage, Channel, Image_height, Image_width] = size(Image);
Output_hight = fix((Image_height-Block_height+2*Padding)/Stride)+1;
Output_width = fix((Image_width-Block_width+2*Padding)/Stride)+1;
col = zeros(NumImage,Block_height*Block_width,Channel,Output_hight,Output_width);
Image = padarray(Image,[0 0 Padding Padding],0,'both');
        % MATLAB の im2col の配列にするには次のようにする.
        %col(:,(h-1)*Block_width+w,:,:,:)=Image(:,:,w:Stride:w-1+WS ,h:Stride:h-1+HS);
        % この部分は,Python などの他の文献のようにするには次のようにする.
         col(:,(h-1)*Block_width+w,:,:,:)=Image(:,:, h:Stride:h-1+HS, w:Stride:w-1+WS);
% MATLAB式 im2col にするには次のようにする:
% Col = permute(col,[2 3 4 5 1]);
% この部分は,Python などの他の文献のようにするには次のようにする.
 Col = permute(col,[2 3 5 4 1]);
Col = reshape(Col, [Channel*Block_height*Block_width NumImage*Output_hight*Output_width ]);
Col = permute(Col,[2 1]);
 
Appendix 2 MATLAB の reshape について
disp(A)
     1     2     3     4     5     6     7     8     9    10    11    12    13    14    15    16    17    18    19    20    21    22    23    24    25    26    27    28    29    30    31    32    33    34    35    36    37    38    39    40    41    42    43    44    45    46    47    48    49    50    51    52    53    54
A = reshape(A,[2 3 3 3]); %[Channel Output_width Output_height NumImage]
        disp(['A(',num2str(i),',:,:,',num2str(j),') =']);
        disp(squeeze(A(i,:,:,j)));
end
    19    25    31
    21    27    33
    23    29    35
    37    43    49
    39    45    51
    41    47    53
    20    26    32
    22    28    34
    24    30    36
    38    44    50
    40    46    52
    42    48    54
A = permute(A,[4 1 3 2]);%[Channel Output_height Output_width NumImage]
        disp(['A(',num2str(i),',',num2str(j),',:,:) =']);
        disp(squeeze(A(i,j,:,:)));
end
    19    21    23
    25    27    29
    31    33    35
    20    22    24
    26    28    30
    32    34    36
    37    39    41
    43    45    47
    49    51    53
    38    40    42
    44    46    48
    50    52    54
  
参考文献
[1] 斎藤康毅,ゼロから作る Deep Learning,O'REILLY, 2016.
[2] 立石賢吾,やさしく学ぶディープラーニングがわかる数学のきほん,マイナビ,2019.
[3] 我妻幸長,はじめてのディープラーニング - Python で学ぶニューラルネットワークとバックプロパゲーション,SB Creative, 2018.