col2im,逆最大プーリングを MATLAB で解説
  
新井仁之(早稲田大学)
公開日 2025年1月1日
  
このノートでは,ディープラーニングで良く知られた col2im と unpooling (逆最大プーリング) について,丁寧に解説することが目的である.なおここでは MATLAB を使う.
なお,Im2Col についてはすでに
「深層学習で使われる im2col を MATLAB で解説」(http://www.araiweb.matrix.jp/Program/Im2Col_tutorial2.html)
で解説してあり,重複する部分もあるが,簡単な復習から始める(第1節).次に col2im について述べる.
ここでの im2col と col2im を,Python の良く知られた im2col,col2im,および MATLAB に内装されている im2col,col2im と区別するために,Im2Col,Col2Im で表す.
なお以下に記す Im2Col と Col2Im は Python のコード([1], [2], [3],とりわけ [2])を元に MATLAB で作成したものである.
 
1.Im2Col と畳み込みについて(復習)
 
1.1  im2col (ここでは Im2Col)はなぜ必要か.
  
画像 X のフィルタ F による畳み込み積(数学分野における相関積)は次のように定義される.
を畳み込み積(数学分野では相関積)という.なお数学分野で畳み込み積は  を
 を  とする.
 とする.  の範囲は畳み込みの種類によって変わる.MATLABでは valid, same, full があるが,ここでは記述を簡単のにするため,画像のゼロパディングによる拡張をしない valid 方式を使う.この場合は,
 の範囲は畳み込みの種類によって変わる.MATLABでは valid, same, full があるが,ここでは記述を簡単のにするため,画像のゼロパディングによる拡張をしない valid 方式を使う.この場合は, 
 とする.したがって畳み込みの後の画像は小さくなる.
im2col(Im2Col)はこの畳み込み演算を,画像とフィルタの配列を変えて行列の積で実現するためのものである.
そのためまず畳み込み積の具体例を見ておこう.
 
1.2 畳み込み積の具体例
  
簡単な一例を示しておく.画像としては次のものを考える.
Image_width = Image_height;
Image = zeros(NumImages,Channel_Image,Image_height,Image_width);
Image1 = 1:Image_height*Image_width;
Image1 = reshape(Image1, [Image_width,Image_height])';
    sq = squeeze(Image(i,j,:,:));
    disp(['Image(',num2str(1),',',num2str(1),',:,:) = ']);
end
     1     2     3     4     5     6
     7     8     9    10    11    12
    13    14    15    16    17    18
    19    20    21    22    23    24
    25    26    27    28    29    30
    31    32    33    34    35    36
フィルタは次のものである.
Filter_width = Filter_height;
Filter = zeros(1,1,Filter_height,Filter_width);
        Filter(i,j,:,:) = [0 1 0; 1 1 1; 0 1 0]; %Laplacian Filter
        sq = squeeze(Filter(i,j,:,:));
        disp(['Filter(',num2str(i),',',num2str(j),';:,:) = ']);
このとき,畳み込みは画像をフィルタサイズにブロック分けして,フィルタの各成分を掛け合わせるという操作をする.イメージ図としては次のようになる.ただし MATLAB仕様ではなく Python 仕様の配列.
まず一つ目のブロック(赤点線で囲った部分)とフィルタの各成分の積の和,つまり内積: 
次に二つ目のブロック(青点線で囲った部分)とフィルタの内積をとる:
この計算を続けて,五つ目のブロック(緑点線で囲った部分)とフィルタとの内積をする:
他も同様に計算すると結果はつぎのようになる.
さてここで,このブロック行列ごとの計算を,一発の行列の積で計算するためためにはどうすればよいか.
たとえば一つ目のブロックの計算であるが,これはブロックに区切った配列の成分を1列にならば,一方,フィルタの方も一列にならべて各成分の掛け算の和を考えればよい.つまり
とすれば,一つ目のブロックと同じ計算をすることになる.
また,二つ目のブロックの計算は
である.これらをまとめて計算するには
画像データをブロックごとに取り出し,一列に横に並べ,それらの結果を縦にならべた配列をつくる.それにフィルタを縦配列に直し,行列としての積をとればよい.
この操作のうち画像データの配列変更をプログラム化したものが Im2Col である.Im2Colの内容の説明は既に行ったので,ここでは計算結果のみを示しておく.stride = 1 で畳み込みをおこなっていくので,画像サイズが 6 x 6,ブロックサイズ 3 x 3 であるから,全部で 
行必要になる.つまり 16 x 9 の配列になる.Im2Col はこの配列を作ってくれる:
block = [size(Filter,3), size(Filter,4)];
Col = Im2Col(Image,block,Stride,Padding);
disp(Col)
     1     2     3     7     8     9    13    14    15
     2     3     4     8     9    10    14    15    16
     3     4     5     9    10    11    15    16    17
     4     5     6    10    11    12    16    17    18
     7     8     9    13    14    15    19    20    21
     8     9    10    14    15    16    20    21    22
     9    10    11    15    16    17    21    22    23
    10    11    12    16    17    18    22    23    24
    13    14    15    19    20    21    25    26    27
    14    15    16    20    21    22    26    27    28
    15    16    17    21    22    23    27    28    29
    16    17    18    22    23    24    28    29    30
    19    20    21    25    26    27    31    32    33
    20    21    22    26    27    28    32    33    34
    21    22    23    27    28    29    33    34    35
    22    23    24    28    29    30    34    35    36
一方,フィルタの方は,次のように縦ベクトルに配列しなおす.
F = reshape(Filter,[9 1]);
結果は
disp(Conv_matrix)
    40
    45
    50
    55
    70
    75
    80
    85
   100
   105
   110
   115
   130
   135
   140
   145
あとは,これをもとの画像のサイズに併せて配列しなおせば,畳み込みの計算が終了する.
Conv_image = reshape(Conv_matrix,[4 4])';
disp(Conv_image);
    40    45    50    55
    70    75    80    85
   100   105   110   115
   130   135   140   145
実際,Cnv2 でも計算をしておく.
Bias = 0; Stride =1; Padding = 0;
Z = Cnv2(Image,Filter,Bias,Stride,Padding);
sq = squeeze(Z(1,1,:,:));
disp(sq)
    40    45    50    55
    70    75    80    85
   100   105   110   115
   130   135   140   145
  
2. Col2Im とはどのようなものか
 
Col2Im は誤差逆伝播法で,逆伝播の部分の計算を軽く行うためのものである.
 
2.1  Col2Im とは
すでに設定した簡単な具体例 Image で説明する.
ここでは 3x3 のフィルタで畳み込み積を作ることを想定して,ブロックサイズは [3, 3] となっている.
Block_size = [Filter_height Filter_width];
bs1 = Block_size(1); bs2= Block_size(2);
disp(['size(Block_size) = [',num2str(Block_size),']']);
Stride = 1, Padding =0 と設定して話を進める.
Padding ≥ 0 の場合も含めて Image をゼロパディングしておく.Padding = 0 の場合は,Im_pad = Image である.
Im_pad = padarray(Image,[0,0,Padding,Padding],0,'both');
squeeze(Im_pad(1,1,:,:))
     1     2     3     4     5     6
     7     8     9    10    11    12
    13    14    15    16    17    18
    19    20    21    22    23    24
    25    26    27    28    29    30
    31    32    33    34    35    36
さて,ここで畳み込みをする際に使う画像データの各画素の重複を数えておこう.下図を参考に説明する.
左の図は 3x3 のフィルタで畳み込みをしたときの 2 に関わる部分である.右に 1ステップずらしたフィルタとの積に関係するため,2回重複して計算されていることがわかる.たとえば 8 については,右の図のように 4 回重複して計算されている.したがって,誤差逆伝播法で戻ってきたときに,戻ってきた結果が,それぞれの重複回数だけ足し合わされる.
Im2Colで変換したデータは次のものである.
col = Im2Col(Im_pad,Block_size,Stride,Padding);
disp(col)
     1     2     3     7     8     9    13    14    15
     2     3     4     8     9    10    14    15    16
     3     4     5     9    10    11    15    16    17
     4     5     6    10    11    12    16    17    18
     7     8     9    13    14    15    19    20    21
     8     9    10    14    15    16    20    21    22
     9    10    11    15    16    17    21    22    23
    10    11    12    16    17    18    22    23    24
    13    14    15    19    20    21    25    26    27
    14    15    16    20    21    22    26    27    28
    15    16    17    21    22    23    27    28    29
    16    17    18    22    23    24    28    29    30
    19    20    21    25    26    27    31    32    33
    20    21    22    26    27    28    32    33    34
    21    22    23    27    28    29    33    34    35
    22    23    24    28    29    30    34    35    36
ここで Im2Col をしたときに重複が起こる.たとえば次を参照.
2 は 2回,8は4回,16は9回出てきている.これらの重複度も込めて,この設定で,これを再び 6x 6 画像に戻す変換が Col2Im である.詳しくは後で述べるが,まずは結果のみ記しておく.これは次のコマンド Col2Im (Appendix 3)で実行される:
img = Col2Im(col,Image,Block_size,Stride,Padding);
squeeze(img(1,1,:,:))
     1     4     9    12    10     6
    14    32    54    60    44    24
    39    84   135   144   102    54
    57   120   189   198   138    72
    50   104   162   168   116    60
    31    64    99   102    70    36
たとえば 2 は元々の画像では (1,2) 成分のところにあったが,ここには 2*(重複回数) = 2*2 = 4 が入っている.16は (3, 4) 成分のところにあったが,ここには 16*(重複回数) = 16*9 = 144 が入っている.
  
2.2 重複の仕方を見るため,次の単純なテスト画像の場合を見ておこう.
Test_Image = zeros(1,1,6,6);
Test_Image(1,1,:,:) = ones(6,6);
squeeze(Test_Image(1,1,:,:))
     1     1     1     1     1     1
     1     1     1     1     1     1
     1     1     1     1     1     1
     1     1     1     1     1     1
     1     1     1     1     1     1
     1     1     1     1     1     1
これに対して Col2Im を作用させると各座標での重複の回数がわかる.
Col = Im2Col(Test_Image,Block_size,Stride,Padding);
img = Col2Im(Col,Test_Image,Block_size,Stride,Padding);
squeeze(img(1,1,:,:))
     1     2     3     3     2     1
     2     4     6     6     4     2
     3     6     9     9     6     3
     3     6     9     9     6     3
     2     4     6     6     4     2
     1     2     3     3     2     1
  
2.3 Col2Im の計算手順
 
2.3.1 Stride = 1,Padding = 0 の場合
具体的に Col2Im がどのような手順で計算しているかを見ていく.
まず設定から.画像データは前と同じく Image とする.
計算は畳み込みのタイプによって変わる.(ただし今はとりあえず valid 型を考えているので,Padding = 0 である.)
Block_size = [Filter_height, Filter_width];
Im_pad = padarray(Image,[0,0,Padding,Padding],0,'both');
入力の元になる col を作る.
Col = Im2Col(Im_pad,Block_size,Stride,Padding);
disp(['size(Col) = [',num2str(size(Col)),']'])
Im_pad を Col にしたものは次のようになっている.
disp(Col)
     1     2     3     7     8     9    13    14    15
     2     3     4     8     9    10    14    15    16
     3     4     5     9    10    11    15    16    17
     4     5     6    10    11    12    16    17    18
     7     8     9    13    14    15    19    20    21
     8     9    10    14    15    16    20    21    22
     9    10    11    15    16    17    21    22    23
    10    11    12    16    17    18    22    23    24
    13    14    15    19    20    21    25    26    27
    14    15    16    20    21    22    26    27    28
    15    16    17    21    22    23    27    28    29
    16    17    18    22    23    24    28    29    30
    19    20    21    25    26    27    31    32    33
    20    21    22    26    27    28    32    33    34
    21    22    23    27    28    29    33    34    35
    22    23    24    28    29    30    34    35    36
この Col (すなわち Im2Col の)出力のサイズの内訳をみておく.
まず valid 型の畳み込みの出力のサイズの計算をする
Block_height = Block_size(1);
Block_width  = Block_size(2);
[NumImage,Channel_Image,image_height,image_width] = size(Image);
Output_height = fix((image_height - Block_height + 2*Padding)/Stride)+1;
Output_width = fix((image_width - Block_width + 2*Padding)/Stride)+1;
disp(['[Output_height, Output_width] = [',num2str(Output_height),',',num2str(Output_width),']']);
[Output_height, Output_width] = [4,4]
 
したがって Col のサイズは入力画像データ Im_pad を Im2Col で変形したものだから,その内訳は
(「Output_height」*「Output_width」*「NumImage」)  x
                     x (「Block_height」*「Block_width」*「Channel_Image」)
となっている.
注:このことから,Col2Im の入力データの型は上記のものと設定される.
入力画像の型を転置して
(「Block_height」*「Block_width」*「Channel_Image」)   x  
                   x (「Output_height」*「Output_width」*「NumImage」) 
としておく.すなわち
Col = permute(Col, [2 1]); 
% disp(['size(Col) = [',num2str(size(Col)),']']);
disp(Col)
     1     2     3     4     7     8     9    10    13    14    15    16    19    20    21    22
     2     3     4     5     8     9    10    11    14    15    16    17    20    21    22    23
     3     4     5     6     9    10    11    12    15    16    17    18    21    22    23    24
     7     8     9    10    13    14    15    16    19    20    21    22    25    26    27    28
     8     9    10    11    14    15    16    17    20    21    22    23    26    27    28    29
     9    10    11    12    15    16    17    18    21    22    23    24    27    28    29    30
    13    14    15    16    19    20    21    22    25    26    27    28    31    32    33    34
    14    15    16    17    20    21    22    23    26    27    28    29    32    33    34    35
    15    16    17    18    21    22    23    24    27    28    29    30    33    34    35    36
ここで Col のサイズは
 (「Block_height * Block_width」*「Channel_Image」)  x  (「Output_height」 * 「Output_width」*「NumImage」)
となったが,これに reshape を使えば,
「Block_height * Block_width」 x 「Channel_Image」 x 「Output_height」x「Output_width」 x 「NumImage」
に分解できる.すなわち
Col = reshape(Col,[Block_height*Block_width Channel_Image Output_width Output_height NumImage]);
    sq = squeeze(Col(i,1,:,:));
    disp(['Col(',num2str(i),',1,:,:) = ']);
end
     1     7    13    19
     2     8    14    20
     3     9    15    21
     4    10    16    22
     2     8    14    20
     3     9    15    21
     4    10    16    22
     5    11    17    23
     3     9    15    21
     4    10    16    22
     5    11    17    23
     6    12    18    24
     7    13    19    25
     8    14    20    26
     9    15    21    27
    10    16    22    28
     8    14    20    26
     9    15    21    27
    10    16    22    28
    11    17    23    29
     9    15    21    27
    10    16    22    28
    11    17    23    29
    12    18    24    30
    13    19    25    31
    14    20    26    32
    15    21    27    33
    16    22    28    34
    14    20    26    32
    15    21    27    33
    16    22    28    34
    17    23    29    35
    15    21    27    33
    16    22    28    34
    17    23    29    35
    18    24    30    36
さらに,後の計算のために,それを次のように permute する.
「NumImage」 x 「Block_height * Block_width」 x 「Channel_Image」 x 「Output_height」x「Output_width」
特に MATLAB は縦方向にデータを読んでいることにも注意.
Col = permute(Col,[5 1 2 4 3]); 
%% これにより[NumImage  Block_height*Block_width  Channel  Output_height  Output_width]
disp(['size(Col) = [',num2str(size(Col)),']']);
    sq = squeeze(Col(1,i,1,:,:));
    disp(['Col(1,',num2str(i),',:,:) = ']);
end
     1     2     3     4
     7     8     9    10
    13    14    15    16
    19    20    21    22
     2     3     4     5
     8     9    10    11
    14    15    16    17
    20    21    22    23
     3     4     5     6
     9    10    11    12
    15    16    17    18
    21    22    23    24
     7     8     9    10
    13    14    15    16
    19    20    21    22
    25    26    27    28
     8     9    10    11
    14    15    16    17
    20    21    22    23
    26    27    28    29
     9    10    11    12
    15    16    17    18
    21    22    23    24
    27    28    29    30
    13    14    15    16
    19    20    21    22
    25    26    27    28
    31    32    33    34
    14    15    16    17
    20    21    22    23
    26    27    28    29
    32    33    34    35
    15    16    17    18
    21    22    23    24
    27    28    29    30
    33    34    35    36
これを一つずつずらして加算していけばよいのであるが,このための計算結果を格納する枠組みを作っておく.
image_scheme = zeros(NumImage,Channel_Image,image_height+2*Padding+Stride-1, ...
    image_width+2*Padding+Stride-1);
disp(['size(image_scheme) = [',num2str(size(image_scheme)),']']);
size(image_scheme) = [1  1  6  6]
次の操作が本質的な部分である.(上記のブロックを一つずつずらして加算)
        coll = zeros(NumImage,Channel_Image,Output_height,Output_width);
        coll(:,:,:,:) = Col(:,(h-1)*Block_width+w,:,:,:);
        image_scheme(:,:, h:Stride:(h-1)+Output_height*Stride, w:Stride:(w-1)+Output_width*Stride) = image_scheme(:,:, ...
            h:Stride:(h-1)+Output_height*Stride, w:Stride:(w-1)+Output_width*Stride)+coll(:,:,:,:);
        squeeze(image_scheme(1,1,:,:))
end
     1     2     3     4     0     0
     7     8     9    10     0     0
    13    14    15    16     0     0
    19    20    21    22     0     0
     0     0     0     0     0     0
     0     0     0     0     0     0
     1     4     6     8     5     0
     7    16    18    20    11     0
    13    28    30    32    17     0
    19    40    42    44    23     0
     0     0     0     0     0     0
     0     0     0     0     0     0
     1     4     9    12    10     6
     7    16    27    30    22    12
    13    28    45    48    34    18
    19    40    63    66    46    24
     0     0     0     0     0     0
     0     0     0     0     0     0
     1     4     9    12    10     6
    14    24    36    40    22    12
    26    42    60    64    34    18
    38    60    84    88    46    24
    25    26    27    28     0     0
     0     0     0     0     0     0
     1     4     9    12    10     6
    14    32    45    50    33    12
    26    56    75    80    51    18
    38    80   105   110    69    24
    25    52    54    56    29     0
     0     0     0     0     0     0
     1     4     9    12    10     6
    14    32    54    60    44    24
    26    56    90    96    68    36
    38    80   126   132    92    48
    25    52    81    84    58    30
     0     0     0     0     0     0
     1     4     9    12    10     6
    14    32    54    60    44    24
    39    70   105   112    68    36
    57   100   147   154    92    48
    50    78   108   112    58    30
    31    32    33    34     0     0
     1     4     9    12    10     6
    14    32    54    60    44    24
    39    84   120   128    85    36
    57   120   168   176   115    48
    50   104   135   140    87    30
    31    64    66    68    35     0
     1     4     9    12    10     6
    14    32    54    60    44    24
    39    84   135   144   102    54
    57   120   189   198   138    72
    50   104   162   168   116    60
    31    64    99   102    70    36
加算し終わった最後の計算結果を,Padding してある場合は,Paddingの部分を削除する.(この例では,Padding = 0 の設定なので切り取りはしないことになる.)
Img = image_scheme(:,:,Padding+1:image_height+Padding,Padding+1:image_width+Padding);
disp(['size(Img) = [', num2str(size(Img)),']']);
disp(squeeze(Img(1,1,:,:)))
     1     4     9    12    10     6
    14    32    54    60    44    24
    39    84   135   144   102    54
    57   120   189   198   138    72
    50   104   162   168   116    60
    31    64    99   102    70    36
 
2.3.2  Stride = 1, Padding =1 の場合
Image を使って計算を見ていく.
Block_size = [Filter_height, Filter_width];
今度は Paddeing = 1 なので Im_pad はゼロパディングで拡張されている:
Im_pad = padarray(Image,[0,0,Padding,Padding],0,'both');
squeeze(Im_pad(1,1,:,:))
     0     0     0     0     0     0     0     0
     0     1     2     3     4     5     6     0
     0     7     8     9    10    11    12     0
     0    13    14    15    16    17    18     0
     0    19    20    21    22    23    24     0
     0    25    26    27    28    29    30     0
     0    31    32    33    34    35    36     0
     0     0     0     0     0     0     0     0
入力の元になる col を作る.
Col = Im2Col(Image,Block_size,Stride,Padding);
disp(['size(Col) = [', num2str(size(Col)),']'])
disp(Col)
     0     0     0     0     1     2     0     7     8
     0     0     0     1     2     3     7     8     9
     0     0     0     2     3     4     8     9    10
     0     0     0     3     4     5     9    10    11
     0     0     0     4     5     6    10    11    12
     0     0     0     5     6     0    11    12     0
     0     1     2     0     7     8     0    13    14
     1     2     3     7     8     9    13    14    15
     2     3     4     8     9    10    14    15    16
     3     4     5     9    10    11    15    16    17
     4     5     6    10    11    12    16    17    18
     5     6     0    11    12     0    17    18     0
     0     7     8     0    13    14     0    19    20
     7     8     9    13    14    15    19    20    21
     8     9    10    14    15    16    20    21    22
     9    10    11    15    16    17    21    22    23
    10    11    12    16    17    18    22    23    24
    11    12     0    17    18     0    23    24     0
     0    13    14     0    19    20     0    25    26
    13    14    15    19    20    21    25    26    27
    14    15    16    20    21    22    26    27    28
    15    16    17    21    22    23    27    28    29
    16    17    18    22    23    24    28    29    30
    17    18     0    23    24     0    29    30     0
     0    19    20     0    25    26     0    31    32
    19    20    21    25    26    27    31    32    33
    20    21    22    26    27    28    32    33    34
    21    22    23    27    28    29    33    34    35
    22    23    24    28    29    30    34    35    36
    23    24     0    29    30     0    35    36     0
     0    25    26     0    31    32     0     0     0
    25    26    27    31    32    33     0     0     0
    26    27    28    32    33    34     0     0     0
    27    28    29    33    34    35     0     0     0
    28    29    30    34    35    36     0     0     0
    29    30     0    35    36     0     0     0     0
% Col は変形していくので,オリジナルはほぞんしておく(後で使う).
Image の畳み込みの出力のサイズの計算をする.Padding = 1 なので,結果は same 型のものになる.
Block_height = Block_size(1);
Block_width  = Block_size(2);
[NumImage,Channel_Image,image_height,image_width] = size(Image);
Output_height = fix((image_height - Block_height + 2*Padding)/Stride)+1;
Output_width = fix((image_width - Block_width + 2*Padding)/Stride)+1;
disp(['[Output_height, Output_width] = [',num2str(Output_height),',',num2str(Output_width),']']);
[Output_height, Output_width] = [6,6]
Col のタイプは
(「NumImage」*「Output_height」*「Output_width」) x (「Channel 」*「 Block_height」*「Block_width」)
である.
入力画像の型を転置して
  (「Channel 」*「 Block_height」*「Block_width」) x (「NumImage」*「Output_height」*「Output_width」)
に変更する:
Col = permute(Col, [2 1]); 
disp(['size(Col) = [',num2str(size(Col)),']']);
これを既述の Padding = 0 のときのように reshape する.
Col = reshape(Col,[Block_height*Block_width Channel_Image Output_width Output_height NumImage]);
Col = permute(Col,[5 1 2 4 3]); 
% [NumImage  Block_height*Block_width  Channel_Image  Output_height  Output_width]
disp(['size(Col) = [',num2str(size(Col)),']']);
今度は Padding = 1 なので結果は次のようになっている.
    sq = squeeze(Col(1,i,1,:,:));
    disp(['Col(1,',num2str(i),',1,:,:) = ']);
end
     0     0     0     0     0     0
     0     1     2     3     4     5
     0     7     8     9    10    11
     0    13    14    15    16    17
     0    19    20    21    22    23
     0    25    26    27    28    29
     0     0     0     0     0     0
     1     2     3     4     5     6
     7     8     9    10    11    12
    13    14    15    16    17    18
    19    20    21    22    23    24
    25    26    27    28    29    30
     0     0     0     0     0     0
     2     3     4     5     6     0
     8     9    10    11    12     0
    14    15    16    17    18     0
    20    21    22    23    24     0
    26    27    28    29    30     0
     0     1     2     3     4     5
     0     7     8     9    10    11
     0    13    14    15    16    17
     0    19    20    21    22    23
     0    25    26    27    28    29
     0    31    32    33    34    35
     1     2     3     4     5     6
     7     8     9    10    11    12
    13    14    15    16    17    18
    19    20    21    22    23    24
    25    26    27    28    29    30
    31    32    33    34    35    36
     2     3     4     5     6     0
     8     9    10    11    12     0
    14    15    16    17    18     0
    20    21    22    23    24     0
    26    27    28    29    30     0
    32    33    34    35    36     0
     0     7     8     9    10    11
     0    13    14    15    16    17
     0    19    20    21    22    23
     0    25    26    27    28    29
     0    31    32    33    34    35
     0     0     0     0     0     0
     7     8     9    10    11    12
    13    14    15    16    17    18
    19    20    21    22    23    24
    25    26    27    28    29    30
    31    32    33    34    35    36
     0     0     0     0     0     0
     8     9    10    11    12     0
    14    15    16    17    18     0
    20    21    22    23    24     0
    26    27    28    29    30     0
    32    33    34    35    36     0
     0     0     0     0     0     0
後はPadding = 0 の場合と同様に計算をしていく.まず計算のための枠組みを作っておく.
image_scheme = zeros(NumImage,Channel_Image,image_height+2*Padding+Stride-1, ...
    image_width+2*Padding+Stride-1);
ずらして上記のブロックを足し合わせて行く部分:
        coll = zeros(NumImage,Channel_Image,Output_height,Output_width);
        coll(:,:,:,:) = Col(:,(h-1)*Block_width+w,:,:,:);
        image_scheme(:,:, h:Stride:(h-1)+Output_height*Stride, w:Stride:(w-1)+Output_width*Stride) = image_scheme(:,:, ...
            h:Stride:(h-1)+Output_height*Stride, w:Stride:(w-1)+Output_width*Stride)+coll(:,:,:,:);
        squeeze(image_scheme(1,1,:,:))
end
     0     0     0     0     0     0     0     0
     0     1     2     3     4     5     0     0
     0     7     8     9    10    11     0     0
     0    13    14    15    16    17     0     0
     0    19    20    21    22    23     0     0
     0    25    26    27    28    29     0     0
     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0
     0     2     4     6     8    10     6     0
     0    14    16    18    20    22    12     0
     0    26    28    30    32    34    18     0
     0    38    40    42    44    46    24     0
     0    50    52    54    56    58    30     0
     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0
     0     2     6     9    12    15    12     0
     0    14    24    27    30    33    24     0
     0    26    42    45    48    51    36     0
     0    38    60    63    66    69    48     0
     0    50    78    81    84    87    60     0
     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0
     0     3     8    12    16    20    12     0
     0    21    32    36    40    44    24     0
     0    39    56    60    64    68    36     0
     0    57    80    84    88    92    48     0
     0    75   104   108   112   116    60     0
     0    31    32    33    34    35     0     0
     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0
     0     4    10    15    20    25    18     0
     0    28    40    45    50    55    36     0
     0    52    70    75    80    85    54     0
     0    76   100   105   110   115    72     0
     0   100   130   135   140   145    90     0
     0    62    64    66    68    70    36     0
     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0
     0     4    12    18    24    30    24     0
     0    28    48    54    60    66    48     0
     0    52    84    90    96   102    72     0
     0    76   120   126   132   138    96     0
     0   100   156   162   168   174   120     0
     0    62    96    99   102   105    72     0
     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0
     0     4    12    18    24    30    24     0
     0    35    56    63    70    77    48     0
     0    65    98   105   112   119    72     0
     0    95   140   147   154   161    96     0
     0   125   182   189   196   203   120     0
     0    93   128   132   136   140    72     0
     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0
     0     4    12    18    24    30    24     0
     0    42    64    72    80    88    60     0
     0    78   112   120   128   136    90     0
     0   114   160   168   176   184   120     0
     0   150   208   216   224   232   150     0
     0   124   160   165   170   175   108     0
     0     0     0     0     0     0     0     0
     0     0     0     0     0     0     0     0
     0     4    12    18    24    30    24     0
     0    42    72    81    90    99    72     0
     0    78   126   135   144   153   108     0
     0   114   180   189   198   207   144     0
     0   150   234   243   252   261   180     0
     0   124   192   198   204   210   144     0
     0     0     0     0     0     0     0     0
Padding してある場合は,その部分を削除する.
Img = image_scheme(:,:,Padding+1:image_height+Padding,Padding+1:image_width+Padding);
disp(squeeze(Img(1,1,:,:)));
     4    12    18    24    30    24
    42    72    81    90    99    72
    78   126   135   144   153   108
   114   180   189   198   207   144
   150   234   243   252   261   180
   124   192   198   204   210   144
コマンドをまとめた関数が Col2Im (Appendix 3 参照)である.Col_org を Col2Im で変換すると,当然上記と同じ結果になる:
ImgCol = Col2Im(Col_org,Img,Block_size,Stride, Padding);
disp(['[size(ImgCol) = [',num2str(size(ImgCol)),']']);
[size(ImgCol) = [1  1  6  6]
squeeze(ImgCol(1,1,:,:))
     4    12    18    24    30    24
    42    72    81    90    99    72
    78   126   135   144   153   108
   114   180   189   198   207   144
   150   234   243   252   261   180
   124   192   198   204   210   144
Col2Im (Appendix 3 参照)で,Test_Image を計算してみる.
% 画像 Test_Image に Col2Im (Stride = 1, Padding =0) を作用させた結果
Col = Im2Col(Test_Image,Block_size,Stride,Padding);
Image_Col = Col2Im(Col,Test_Image,Block_size,Stride,Padding);
squeeze(Image_Col(1,1,:,:))
     1     2     3     3     2     1
     2     4     6     6     4     2
     3     6     9     9     6     3
     3     6     9     9     6     3
     2     4     6     6     4     2
     1     2     3     3     2     1
% 画像 Test_Image に Col2Im (Stride = 1, Padding =1) を作用させた結果
Col = Im2Col(Test_Image,Block_size,Stride,Padding);
Image_Col = Col2Im(Col,Test_Image,Block_size,Stride,Padding);
squeeze(Image_Col(1,1,:,:))
     4     6     6     6     6     4
     6     9     9     9     9     6
     6     9     9     9     9     6
     6     9     9     9     9     6
     6     9     9     9     9     6
     4     6     6     6     6     4
  
3. Unpooling - 最大プーリングの逆のアルゴリズム(全結合層 → 畳み込み層)
  
順伝播で最大プーリングしたが,そこでデシメートした数は 0 に置き換えてアップサンプリングする.すなわち
全結合層 → 最大プーリングしたものを戻す(アップサンプリング)
をしておく必要がある.
そこの部分の計算を簡単な例で説明する.
まず 2 x 2 の max pooling に由来する最初の全結合層の Delta を次のものとする.
max pooling で残した位置もこのデータに付加する.なお以下の位置データは仮想的なものである.
PI = [4,3,1,2,3,1,4,2,3,4,1,2,3,1,4,2,3,4,1,2,1,3,2,4,4,3,1,2,4,3,2,1,3,4,2,1];
% PI は次のようにランダムに作成したものである.
%     A = randperm(4); %2x2 max pooling なので,2*2=4 区切りで 1~4の(ランダムな)配列を作る.
%     PI(1,4*(i-1)+1:4*(i-1)+4)=A;
この最大値の位置データ 1~4 を,長さ 4 の横配列の 1 の位置で記憶させておく.
disp(col_D);
     0     0     0     1
     0     0     2     0
     3     0     0     0
     0     4     0     0
     0     0     5     0
     6     0     0     0
     0     0     0     7
     0     8     0     0
     0     0     9     0
     0     0     0    10
    11     0     0     0
     0    12     0     0
     0     0    13     0
    14     0     0     0
     0     0     0    15
     0    16     0     0
     0     0    17     0
     0     0     0    18
    19     0     0     0
     0    20     0     0
    21     0     0     0
     0     0    22     0
     0    23     0     0
     0     0     0    24
     0     0     0    25
     0     0    26     0
    27     0     0     0
     0    28     0     0
     0     0     0    29
     0     0    30     0
     0    31     0     0
    32     0     0     0
     0     0    33     0
     0     0     0    34
     0    35     0     0
    36     0     0     0
これは,最初の 1 は もともとは 2 x 2 のブロックの (1, 2)  の位置にあり,次の2は二つ目の 2x2 ブロックの (1,0) の位置,以下同様の位置にあったことを意味している.これをもとの位置に配列しなおすのが Col2Im である.全部で 36*4 の成分があるものを sqrt(36*4) = 12  の正方配列にしなおすので,出力は 12x12 である.
また,今回の場合は,Block_size =[2, 2] で Stride =2 なので,Col2Im において重複加算は起こらないことに注意.
Image_type = zeros(1,1,12,12);
X = Col2Im(col_D,Image_type,[2,2],2,0);
disp(['size(X) = [', num2str(size(X)),']']);
disp(squeeze(X(1,1,:,:)))
     0     0     0     0     3     0     0     4     0     0     6     0
     0     1     2     0     0     0     0     0     5     0     0     0
     0     0     0     8     0     0     0     0    11     0     0    12
     0     7     0     0     9     0     0    10     0     0     0     0
     0     0    14     0     0     0     0    16     0     0     0     0
    13     0     0     0     0    15     0     0    17     0     0    18
    19     0     0    20    21     0     0     0     0    23     0     0
     0     0     0     0     0     0    22     0     0     0     0    24
     0     0     0     0    27     0     0    28     0     0     0     0
     0    25    26     0     0     0     0     0     0    29    30     0
     0    31    32     0     0     0     0     0     0    35    36     0
     0     0     0     0    33     0     0    34     0     0     0     0
これで所望の配列が得られていることがわかる.
これをまとめたのが Max_unpooling である:
Image_type = zeros(1,1,12,12);
X = Max_unpooling(Image_type,PI,Delta,block,Stride);
disp(squeeze(X(1,1,:,:)));
     0     0     0     0     3     0     0     4     0     0     6     0
     0     1     2     0     0     0     0     0     5     0     0     0
     0     0     0     8     0     0     0     0    11     0     0    12
     0     7     0     0     9     0     0    10     0     0     0     0
     0     0    14     0     0     0     0    16     0     0     0     0
    13     0     0     0     0    15     0     0    17     0     0    18
    19     0     0    20    21     0     0     0     0    23     0     0
     0     0     0     0     0     0    22     0     0     0     0    24
     0     0     0     0    27     0     0    28     0     0     0     0
     0    25    26     0     0     0     0     0     0    29    30     0
     0    31    32     0     0     0     0     0     0    35    36     0
     0     0     0     0    33     0     0    34     0     0     0     0
Appendix 1. im2col の関数ファイル
function Col = Im2Col(Image,block,Stride,Padding)
%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Image = Image(Number of Filter, Channel, Image_height, Image_width)
% block = [number1, number2]
% Stride = number, Padding = number
% (NumImage*Output_height * Output_width) x (Channel * Block_height*Block_width)
% Output_hight = fix((Image_height-Block_height+2*Padding)/Stride)+1
% Output_width = fix((Image_width-Block_width+2*Padding)/Stride)+1
% 次の [A], [B], [C], [D] の Pythonプログラムにもとづく.
% [A] https://docs.chainer.org/en/v7.8.1.post1/reference/generated/chainer.functions.im2col.html
% [B] 斎藤康毅,ゼロから作る Deep Learning,O'REILLY, 2016.
% [C] 立石賢吾,やさしく学ぶディープラーニングがわかる数学のきほん,マイナビ,2019.
% [D] 我妻幸長,はじめてのディープラーニング - Python で学ぶニューラルネットワークとバックプロパゲーション,SB Creative, 2018.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%
[NumImage, Channel, Image_height, Image_width] = size(Image);
Output_hight = fix((Image_height-Block_height+2*Padding)/Stride)+1;
Output_width = fix((Image_width-Block_width+2*Padding)/Stride)+1;
col = zeros(NumImage,Block_height*Block_width,Channel,Output_hight,Output_width);
Image = padarray(Image,[0 0 Padding Padding],0,'both');
        % MATLAB の im2col の配列にするには次のようにする.
        %col(:,(h-1)*Block_width+w,:,:,:)=Image(:,:,w:Stride:w-1+WS ,h:Stride:h-1+HS);
        % この部分は,Python などの他の文献のようにするには次のようにする.
         col(:,(h-1)*Block_width+w,:,:,:)=Image(:,:, h:Stride:h-1+HS, w:Stride:w-1+WS);
% MATLAB式 im2col にするには次のようにする:
% Col = permute(col,[2 3 4 5 1]);
% この部分は,Python などの他の文献のようにするには次のようにする.
 Col = permute(col,[2 3 5 4 1]);
Col = reshape(Col, [Channel*Block_height*Block_width NumImage*Output_hight*Output_width ]);
Col = permute(Col,[2 1]);
  
Appendix 2.Cnv2 の関数ファイル
function Z = Cnv2(Image,Filter,Bias,Stride,Padding)
%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Image = Image(Number of Images, Channel, Image_height, Image_width)
% Filter = Filter(Number of Filters, Channel, Filter_height, Filter_width)
% Number of Images x Number of Filters x output_hight x output_width
% Output_hight = fix((Image_height-Block_height+2*Padding)/Stride)+1
% Output_width = fix((Image_width-Block_width+2*Padding)/Stride)+1
% 本ノートの参考文献 [2], [3], [4] の Pythonプログラムにもとづく.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%
[NumImage, ~, Image_height, Image_width] = size(Image);
[NumFilter,Channel,Filter_height,Filter_width] = size(Filter);
Output_height = fix((Image_height -Filter_height +2*Padding)/Stride)+1;
Output_width = fix((Image_width -Filter_width +2*Padding)/Stride)+1;
Image = Im2Col(Image,[Filter_height,Filter_width],Stride,Padding);
% Im2Col が Python 仕様の配列の場合
Filter = permute(Filter,[4 3 2 1]);
% Im2Col が MATLAB 仕様の配列方式の場合
%Filter = permute(Filter,[3 4 2 1]);
Filter = reshape(Filter,[Channel*Filter_height*Filter_width NumFilter])';
Filter = permute(Filter,[2 1]);
Z = affine_product(Image,Filter,Bias);
Z = reshape(Z,[NumFilter Output_width Output_height NumImage]);
 Z = permute(Z,[4 1 3 2]);
%Z = permute(Z,[4 1 2 3]);
    if mod(Filter_height,2) == 0
        Z = Z(:,:,2:Output_height,:);
    if mod(Filter_width,2) == 0
        Z = Z(:,:,:,2:Output_width);
Cnv2 で使う affine_product 関数の定義
function C = affine_product(A,B,b)
% 入力:画像 A,ファイルタ B,バイアス b = 実数
 
Appendix 3. Col2Im の関数ファイル
function Img = Col2Im(Col, Image_type, Block_size,Stride,Padding)
%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Col = Im2Col(Image,block,Stride,Padding) に対して,
% Img = Col2Im(Col,Image_type,block,Stride,Padding) となっていることを想定している.
% [NumImage, Channel, Image_height, Image_width] = size(Image)
% Image_type は [NumImage, Channel, Image_height, Image_width] 
% block はデータを区切りブロック化するサイズを指定.畳み込みのフィルタのサイズに相当.
% Stride はデータをずらす幅.畳み込みのずらし幅に相当.
% Padding は画像のゼロパディングの幅.畳み込みの出力サイズを定める幅に相当.
% [NumFilter,Channel,Image_height,Image_width] = size(Image_type);
% [1], [2], [3] の Pythonプログラムにもとづく.
%   [1] 斎藤康毅,ゼロから作る Deep Learning,O'REILLY, 2016.
%   [2] 立石賢吾,やさしく学ぶディープラーニングがわかる数学のきほん,マイナビ,2019.
%   [3] 我妻幸長,はじめてのディープラーニング - Python で学ぶニューラルネットワークとバックプロパゲーション,SB Creative, 2018.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Block_height = Block_size(1);
Block_width  = Block_size(2);
[NumFilter,Channel,Image_height,Image_width] = size(Image_type);
Output_height = fix((Image_height - Block_height + 2*Padding)/Stride)+1;
Output_width = fix((Image_width - Block_width + 2*Padding)/Stride)+1;
Col = permute(Col, [2 1]);
Col = reshape(Col,[Block_height*Block_width Channel Output_width Output_height NumFilter]);
Col = permute(Col,[5 1 2 4 3]); %[NumFilter Block_height*Block_width Channel Output_height Output_width]
im_scheme = zeros(NumFilter,Channel,Image_height+2*Padding+Stride-1, ...
    Image_width+2*Padding+Stride-1);
        coll = zeros(NumFilter,Channel,Output_height,Output_width);
        coll(:,:,:,:) = Col(:,(h-1)*Block_width+w,:,:,:);
        im_scheme(:,:, h:Stride:(h-1)+Output_height*Stride, w:Stride:(w-1)+Output_width*Stride) = im_scheme(:,:, ...
            h:Stride:(h-1)+Output_height*Stride, w:Stride:(w-1)+Output_width*Stride)+coll(:,:,:,:);
Img = im_scheme(:,:,Padding+1:Image_height+Padding,Padding+1:Image_width+Padding);
 
Appendix 4.  Max_unpooling の関数ファイル
function X = Max_unpooling(Image_type,PI,Delta,block,Stride)
%[NumImage,Channel_Images,Image_height,Image_width] = size(Image_type);
% 下記参考文献 [2] の Python プログラムに基づく.
D_flat = reshape_1(Delta);
    col_D(k,PI(k))=D_flat(k);
X=Col2Im(col_D,Image_type,block,Stride,0);
Max_unpooling で使う reshape_1 の定義
function Y = reshape_1(X)
m2 = permute(X,ndims(X):-1:1); % ndims(X) は X の次元数
m3 = reshape(m2,numel(X)/N,N); % numel(X) は X の要素数
Y = permute(m3,ndims(m3):-1:1);
  
参考文献
[1] 斎藤康毅,ゼロから作る Deep Learning,O'REILLY, 2016.
[2] 立石賢吾,やさしく学ぶディープラーニングがわかる数学のきほん,マイナビ,2019.
[3] 我妻幸長,はじめてのディープラーニング - Python で学ぶニューラルネットワークとバックプロパゲーション,SB Creative, 2018.
[4] 新井仁之「深層学習で使われる im2col を MATLAB で解説」(http://www.araiweb.matrix.jp/Program/Im2Col_tutorial2.html)
  
  
Copyright © Hitoshi Arai, 2025