00001
00007
00008
00009
00010
00011
00012
00013
/** @brief Integer square root
 **
 ** Computes the largest integer @c r such that <code>r * r <= val</code>
 ** using the binary digit-by-digit method (one result bit per step,
 ** no multiplications or divisions).
 **
 ** The whole range of @c unsigned @c long is handled, so the result is
 ** also exact on platforms where @c unsigned @c long is wider than 32 bits.
 **
 ** @param val value whose square root is taken.
 ** @return floor of the square root of @a val.
 **/
static unsigned int
isqrt(unsigned long val)
{
  unsigned long root = 0 ;

  /* Second-highest bit of unsigned long: the largest power of four that
     the type can hold, derived without relying on <limits.h>. */
  unsigned long bit = (~0UL >> 1) ^ (~0UL >> 2) ;

  /* Skip the powers of four that exceed the operand. */
  while (bit > val) {
    bit >>= 2 ;
  }

  /* Resolve one bit of the root per iteration, from the top down. */
  while (bit != 0) {
    if (val >= root + bit) {
      val -= root + bit ;
      root = (root >> 1) + bit ;
    } else {
      root >>= 1 ;
    }
    bit >>= 2 ;
  }

  /* The root of an unsigned long needs at most half of its bits. */
  return (unsigned int) root ;
}
00031
00032
00038 static void
00039 vl_ikm_elkan_update_inter_dist (VlIKMFilt *f)
00040 {
00041 int
00042 i, k, kp,
00043 K = f-> K,
00044 M = f-> M ;
00045 vl_ikm_acc dist, delta ;
00046
00047
00048 for(k = 0 ; k < K ; ++ k) {
00049 for(kp = 0 ; kp < K ; ++ kp) {
00050 dist = 0 ;
00051 if (k != kp) {
00052 for(i = 0 ; i < M ; ++i) {
00053 delta = f->centers [kp*M + i] - f->centers [k*M + i] ;
00054 dist += delta * delta ;
00055 }
00056 }
00057 f->inter_dist [k*K + kp] = f->inter_dist [kp*K + k] = dist >> 2 ;
00058 }
00059 }
00060 }
00061
00067 static void
00068 vl_ikm_init_elkan (VlIKMFilt *f)
00069 {
00070 if (f-> inter_dist) {
00071 vl_free (f-> inter_dist) ;
00072 }
00073 f-> inter_dist = vl_malloc (sizeof(vl_ikm_acc) * f->K*f->K) ;
00074 vl_ikm_elkan_update_inter_dist (f) ;
00075 }
00076
00077
00085 static int
00086 vl_ikm_train_elkan (VlIKMFilt* f, vl_uint8 const* data, int N)
00087 {
00088
00089 int i,pass,c,cp,x,cx ;
00090 int dist_calc = 0 ;
00091 int
00092 K = f-> K,
00093 M = f-> M ;
00094
00095 vl_ikm_acc dist ;
00096 vl_ikm_acc *m_pt = vl_malloc(sizeof(vl_ikm_acc)* M*K) ;
00097 vl_ikm_acc *u_pt = vl_malloc(sizeof(vl_ikm_acc)* N) ;
00098 char *r_pt = vl_malloc(sizeof(char) * 1*N) ;
00099 vl_ikm_acc *s_pt = vl_malloc(sizeof(vl_ikm_acc)* K) ;
00100 vl_ikm_acc *l_pt = vl_malloc(sizeof(vl_ikm_acc)* N*K) ;
00101 vl_ikm_acc *d_pt = f-> inter_dist ;
00102 vl_uint *asgn = vl_malloc (sizeof(vl_uint) * N) ;
00103 vl_uint *counts=vl_malloc (sizeof(vl_uint) * N) ;
00104
00105 int done = 0 ;
00106
00107
00108 vl_ikm_elkan_update_inter_dist (f) ;
00109
00110
00111 memset(l_pt, 0, sizeof(vl_ikm_acc) * N*K ) ;
00112 memset(u_pt, 0, sizeof(vl_ikm_acc) * N ) ;
00113 memset(r_pt, 0, sizeof(char) * N ) ;
00114 for(x = 0 ; x < N ; ++x) {
00115 vl_ikm_acc best_dist ;
00116
00117
00118 dist_calc ++ ;
00119 for(dist = 0, i = 0 ; i < M ; ++i) {
00120 vl_ikm_acc delta = data[x*M + i] - f->centers[i] ;
00121 dist += delta*delta ;
00122 }
00123 cx = 0 ;
00124 best_dist = dist ;
00125 l_pt[x] = dist ;
00126
00127
00128 for(c = 1 ; c < K ; ++c) {
00129 if(d_pt[K*cx+c] < best_dist) {
00130
00131
00132 dist_calc++ ;
00133 for(dist=0, i = 0 ; i < M ; ++i) {
00134 vl_ikm_acc delta = data[x*M + i] - f->centers[c*M + i] ;
00135 dist += delta*delta ;
00136 }
00137
00138
00139 l_pt[N*c + x] = dist ;
00140
00141 if(dist < best_dist) {
00142 best_dist = dist ;
00143 cx = c ;
00144 }
00145 }
00146 }
00147
00148 asgn[x] = cx ;
00149 u_pt[x] = best_dist ;
00150 }
00151
00152
00153
00154
00155
00156 for (pass = 0 ; 1 ; ++ pass) {
00157
00158
00159
00160
00161 memset(m_pt, 0, sizeof(vl_ikm_acc) * M * K) ;
00162 memset(counts, 0, sizeof(vl_ikm_acc) * K) ;
00163
00164
00165 for(x = 0 ; x < N ; ++x) {
00166 int cx = asgn[x] ;
00167 ++ counts[ cx ] ;
00168 for(i = 0 ; i < M ; ++i) {
00169 m_pt[cx*M + i] += data[x*M + i] ;
00170 }
00171 }
00172
00173
00174 for(c = 0 ; c < K ; ++c) {
00175 vl_ikm_acc n = counts[c] ;
00176 if(n > 0) {
00177 for(i = 0 ; i < M ; ++i) {
00178 m_pt[c*M + i] /= n ;
00179 }
00180 } else {
00181 for(i = 0 ; i < M ; ++i) {
00182
00183 }
00184 }
00185 }
00186
00187
00188
00189
00190 for(c = 0 ; c < K ; ++c) {
00191
00192
00193 dist_calc++ ;
00194 for(dist = 0, i = 0 ; i < M ; ++i) {
00195 vl_ikm_acc delta = m_pt[c*M + i] - f->centers[c*M + i] ;
00196 f->centers[c*M + i] = m_pt[c*M +i] ;
00197 dist += delta*delta ;
00198 }
00199 for(x = 0 ; x < N ; ++x) {
00200 vl_ikm_acc lxc = l_pt[c*N + x] ;
00201 vl_uint cx = asgn[x] ;
00202
00203
00204 if(dist < lxc) {
00205 lxc = lxc + dist - 2*(isqrt(lxc)+1)*(isqrt(dist)+1) ;
00206 } else {
00207 lxc = 0 ;
00208 }
00209 l_pt[c*N + x] = lxc ;
00210
00211
00212 if(c == cx) {
00213 vl_ikm_acc ux = u_pt[x] ;
00214 u_pt[x] = ux + dist + 2 * (isqrt(ux)+1)*(isqrt(dist)+1);
00215 r_pt[x] = 1 ;
00216 }
00217 }
00218 }
00219
00220
00221 for(c = 0 ; c < K ; ++c) {
00222 for(cp = 0 ; cp < K ; ++cp) {
00223 dist = 0 ;
00224 if( c != cp ) {
00225 dist_calc++;
00226 for(i = 0 ; i < M ; ++i) {
00227 vl_ikm_acc delta = f->centers[ cp*M + i ] - f->centers[ c*M + i ] ;
00228 dist += delta*delta ;
00229 }
00230 }
00231 d_pt[c*K+cp] = d_pt[cp*K+c] = dist>>2 ;
00232 }
00233 }
00234
00235
00236 for(c = 0 ; c < K ; ++c) {
00237 vl_ikm_acc best_dist = VL_BIG_INT ;
00238 for(cp = 0 ; cp < K ; ++cp) {
00239 dist = d_pt[c*K+cp] ;
00240 if(c != cp && dist < best_dist) best_dist = dist ;
00241 }
00242 s_pt[c] = best_dist >> 2 ;
00243 }
00244
00245
00246
00247
00248
00249 done = 1 ;
00250 for(x=0 ; x < N ; ++x) {
00251 vl_uint cx = asgn[x] ;
00252 vl_ikm_acc ux = u_pt[x] ;
00253
00254
00255
00256
00257
00258
00259 if(ux <= s_pt[cx]) continue ;
00260
00261 for(c = 0 ; c < K ; ++c) {
00262 vl_ikm_acc dist = 0 ;
00263
00264
00265
00266
00267
00268
00269
00270 if(c == cx ||
00271 ux <= l_pt[N*c + x] ||
00272 ux <= d_pt[K*c + cx] )
00273 continue ;
00274
00275
00276
00277
00278
00279 if( r_pt[x] ) {
00280 dist_calc++;
00281 for(dist = 0, i = 0 ; i < M ; ++i) {
00282 vl_ikm_acc delta = data[ x*M + i ] - f->centers[ cx*M + i ] ;
00283 dist += delta*delta ;
00284 }
00285 ux = u_pt[x] = dist ;
00286 r_pt[x] = 0 ;
00287
00288
00289
00290 if(
00291 ux <= l_pt[N*c + x] ||
00292 ux <= d_pt[K*c + cx] )
00293 continue ;
00294 }
00295
00296
00297 dist_calc++ ;
00298 for(dist = 0, i = 0 ; i < M ; ++i) {
00299 vl_ikm_acc delta = data[ x*M + i ] - f->centers[ c*M + i ] ;
00300 dist += delta*delta ;
00301 }
00302
00303 l_pt[N*c + x] = dist ;
00304
00305 if( dist < ux ) {
00306 ux = u_pt[x] = dist ;
00307
00308 asgn[x] = c ;
00309 done = 0 ;
00310 }
00311 }
00312 }
00313
00314
00315 if(done || pass == f->max_niters) {
00316 break ;
00317 }
00318 }
00319
00320 vl_free (counts) ;
00321 vl_free (asgn) ;
00322 vl_free (l_pt) ;
00323 vl_free (s_pt) ;
00324 vl_free (r_pt) ;
00325 vl_free (u_pt) ;
00326 vl_free (m_pt) ;
00327
00328 if (f-> verb) {
00329 VL_PRINTF ("ikm: Elkan algorithm: total iterations: %d\n", pass) ;
00330 VL_PRINTF ("ikm: Elkan algorithm: distance calculations: %d (speedup: %.2f)\n",
00331 dist_calc, (float)N*K*(pass+2) / dist_calc -1) ;
00332 }
00333 return 0 ;
00334 }
00335
00344 static void
00345 vl_ikm_push_elkan (VlIKMFilt *f, vl_uint *asgn, vl_uint8 const *data, int N)
00346 {
00347 vl_uint i,c,cx,x,
00348 dist_calc = 0,
00349 K = f-> K,
00350 M = f-> M ;
00351 vl_ikm_acc dist, best_dist ;
00352 vl_ikm_acc *d_pt = f-> inter_dist ;
00353
00354
00355 for(x=0 ; x < N ; ++x) {
00356 best_dist = VL_BIG_INT ;
00357 cx = 0 ;
00358
00359 for(c = 0 ; c < K ; ++c) {
00360 if(d_pt[K*cx+c] < best_dist) {
00361
00362 dist_calc ++ ;
00363 for(dist=0, i = 0 ; i < M ; ++i) {
00364 vl_ikm_acc delta = data[x*M + i] - f->centers[c*M + i] ;
00365 dist += delta*delta ;
00366 }
00367
00368
00369 if(dist < best_dist) {
00370 best_dist = dist ;
00371 cx = c ;
00372 }
00373 }
00374 }
00375 asgn [x] = cx ;
00376 }
00377 }
00378
00379
00380
00381
00382
00383