bidir.cl 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  /* Mode selection routines: select the lowest-SATD-cost mode for each lowres
 * macroblock. When measuring B slices, this includes measuring the cost of
 * three bidir modes. */
  /*
 * bidir_satd_8x8_ii_coop4 - four work-items cooperatively measure the SATD of
 * an 8x8 lowres block against a weighted average of two qpel predictions,
 * one taken from each reference frame.
 *
 *   fencpos   full-pel position of the block in fenc_lowres
 *   qpos0     qpel position within reference frame 0 (fref0_planes)
 *   qpos1     qpel position within reference frame 1 (fref1_planes)
 *   weight    weight of the list-0 prediction out of 64; list 1 gets 64 - weight
 *   tmpp      local scratch viewed as sum2_t[4][4], shared by the four items
 *   idx       this item's index (0..3); it reads pixel rows idx and idx + 4
 *
 * Returns this item's partial cost; the caller adds the four partial costs.
 *
 * Each reference texel packs four hpel samples in its bytes; sample number
 * hpel (0..3) is extracted by shifting right by 8 * hpel and masking 0xFF.
 * A qpel sample is the rounded average (rhadd) of hpel sample A (at or before
 * the qpel position) and hpel sample B (one hpel further along each axis
 * whose qpel offset is odd).
 *
 * NOTE(review): the row/column exchange through tmp uses no barrier(); it
 * appears to rely on the four cooperating items executing in lockstep and on
 * the volatile qualifier keeping the accesses in program order.
 */
int bidir_satd_8x8_ii_coop4( read_only image2d_t fenc_lowres,
                             int2 fencpos,
                             read_only image2d_t fref0_planes,
                             int2 qpos0,
                             read_only image2d_t fref1_planes,
                             int2 qpos1,
                             int weight,
                             local sum2_t *tmpp,
                             int idx )
{
    volatile local sum2_t (*xchg)[4] = (volatile local sum2_t (*)[4])tmpp;
    const int weight2 = 64 - weight;

    /* Byte shift selecting the hpel plane addressed by qpel position q. */
#define HPEL_SHIFT( q ) ((uint)(8 * ((((q).x & 2) >> 1) + ((q).y & 2))))
    /* qpel position of the second (B) hpel sample for qpel position q. */
#define NEXT_HPEL( q ) ((q) + (int2)((((q).x & 1) << 1), (((q).y & 1) << 1)))
    const int2 q0b = NEXT_HPEL( qpos0 );
    const int2 q1b = NEXT_HPEL( qpos1 );
    const int2 pos0a = (int2)(qpos0.x >> 2, qpos0.y >> 2);
    const int2 pos0b = (int2)(q0b.x >> 2, q0b.y >> 2);
    const int2 pos1a = (int2)(qpos1.x >> 2, qpos1.y >> 2);
    const int2 pos1b = (int2)(q1b.x >> 2, q1b.y >> 2);
    const uint sh0a = HPEL_SHIFT( qpos0 ), sh0b = HPEL_SHIFT( q0b );
    const uint sh1a = HPEL_SHIFT( qpos1 ), sh1b = HPEL_SHIFT( q1b );
#undef NEXT_HPEL
#undef HPEL_SHIFT

    /* Residual of pixel (dx, row): source minus weighted bidir prediction.
     * Kept in uint so negative residuals wrap, matching the packed-sum trick. */
#define BIDIR_RESIDUAL( dst, dx, row )\
    {\
        const int2 o = (int2)(dx, row);\
        uint src = read_imageui( fenc_lowres, sampler, fencpos + o ).s0;\
        uint sa = (read_imageui( fref0_planes, sampler, pos0a + o ).s0 >> sh0a) & 0xFF;\
        uint sb = (read_imageui( fref0_planes, sampler, pos0b + o ).s0 >> sh0b) & 0xFF;\
        uint pred0 = rhadd( sa, sb );\
        sa = (read_imageui( fref1_planes, sampler, pos1a + o ).s0 >> sh1a) & 0xFF;\
        sb = (read_imageui( fref1_planes, sampler, pos1b + o ).s0 >> sh1b) & 0xFF;\
        uint pred1 = rhadd( sa, sb );\
        dst = src - ((pred0 * weight + pred1 * weight2 + (1 << 5)) >> 6);\
    }
    /* Two residuals (columns lo and hi) packed into one sum2_t. */
#define PACKED_RESIDUAL( dst, lo, hi, row )\
    {\
        uint r_lo, r_hi;\
        BIDIR_RESIDUAL( r_lo, lo, row );\
        BIDIR_RESIDUAL( r_hi, hi, row );\
        dst = r_lo + (r_hi << BITS_PER_SUM);\
    }

    sum2_t acc = 0;
    for( int strip = 0; strip < 8; strip += 4 )
    {
        const int row = strip + idx;
        sum2_t c0, c1, c2, c3;
        PACKED_RESIDUAL( c0, 0, 4, row );
        PACKED_RESIDUAL( c1, 1, 5, row );
        PACKED_RESIDUAL( c2, 2, 6, row );
        PACKED_RESIDUAL( c3, 3, 7, row );
        /* Horizontal transform of this item's row is published to row idx of
         * xchg; the vertical transform then reads column idx back. */
        HADAMARD4( xchg[idx][0], xchg[idx][1], xchg[idx][2], xchg[idx][3], c0, c1, c2, c3 );
        HADAMARD4( c0, c1, c2, c3, xchg[0][idx], xchg[1][idx], xchg[2][idx], xchg[3][idx] );
        acc += abs2( c0 ) + abs2( c1 ) + abs2( c2 ) + abs2( c3 );
    }
#undef PACKED_RESIDUAL
#undef BIDIR_RESIDUAL

    /* Add the low and high packed halves, then halve the total. */
    return (((sum_t)acc) + (acc >> BITS_PER_SUM)) >> 1;
}
  /*
 * mode selection - pick the least cost partition type for each 8x8 macroblock:
 * intra, list0 or list1. When measuring a B slice, also test three bidir
 * possibilities: temporal (scaled co-located vector), zero vectors, and the
 * L0+L1 vector pair.
 *
 * fenc_lowres_mvs[0|1] and fenc_lowres_mv_costs[0|1] are large buffers that
 * hold many frames worth of motion vectors. We must offset into the correct
 * location for this frame's vectors:
 *
 *    CPU equivalent: fenc->lowres_mvs[0][b - p0 - 1]
 *    GPU equivalent: fenc_lowres_mvs0[(b - p0 - 1) * mb_count]
 *
 * The result is written to lowres_costs[mb_xy]: the best cost clamped to
 * LOWRES_COST_MASK, plus the chosen mode (0 intra, 1 list0, 2 list1,
 * 3 bidir) shifted left by LOWRES_COST_SHIFT.
 *
 * global launch dimensions for P slice estimate: [mb_width, mb_height]
 * global launch dimensions for B slice estimate: [mb_width * 4, mb_height]
 */
kernel void mode_selection( read_only image2d_t fenc_lowres,
                            read_only image2d_t fref0_planes,
                            read_only image2d_t fref1_planes,
                            const global short2 *fenc_lowres_mvs0,
                            const global short2 *fenc_lowres_mvs1,
                            const global short2 *fref1_lowres_mvs0,
                            const global int16_t *fenc_lowres_mv_costs0,
                            const global int16_t *fenc_lowres_mv_costs1,
                            const global uint16_t *fenc_intra_cost,
                            global uint16_t *lowres_costs,
                            global int *frame_stats,
                            local int16_t *cost_local,
                            local sum2_t *satd_local,
                            int mb_width,
                            int bipred_weight,
                            int dist_scale_factor,
                            int b,
                            int p0,
                            int p1,
                            int lambda )
{
    /* Initialize int frame_stats[4] for the next kernel (sum_inter_cost).
     * A single work-item clears all four entries, so they are reset even
     * when the frame has fewer than four macroblock columns (mb_width < 4). */
    if( get_global_id( 0 ) == 0 && get_global_id( 1 ) == 0 )
    {
        for( int i = 0; i < 4; i++ )
            frame_stats[i] = 0;
    }

    int mb_x = get_global_id( 0 );
    int b_bidir = b < p1;
    if( b_bidir )
    {
        /* when mode_selection is run for B frames, it must perform BIDIR SATD
         * measurements, so it is launched with four times as many threads in
         * order to spread the work around more of the GPU. And it can add
         * padding threads in the X direction. */
        mb_x >>= 2;
        if( mb_x >= mb_width )
            return;
    }
    int mb_y = get_global_id( 1 );
    int mb_height = get_global_size( 1 );
    int mb_count = mb_width * mb_height;
    int mb_xy = mb_x + mb_y * mb_width;

    int bcost = COST_MAX;
    int list_used = 0;
    if( !b_bidir )
    {
        int icost = fenc_intra_cost[mb_xy];
        COPY2_IF_LT( bcost, icost, list_used, 0 );
    }
    if( b != p0 )
    {
        int mv_cost0 = fenc_lowres_mv_costs0[(b - p0 - 1) * mb_count + mb_xy];
        COPY2_IF_LT( bcost, mv_cost0, list_used, 1 );
    }
    if( b != p1 )
    {
        int mv_cost1 = fenc_lowres_mv_costs1[(p1 - b - 1) * mb_count + mb_xy];
        COPY2_IF_LT( bcost, mv_cost1, list_used, 2 );
    }

    if( b_bidir )
    {
        int2 coord = (int2)(mb_x, mb_y) << 3;
        int mb_i = get_global_id( 0 ) & 3;
        int mb_in_group = get_local_id( 1 ) * (get_local_size( 0 ) >> 2) + (get_local_id( 0 ) >> 2);
        /* The four work-items of this MB exchange their partial SATD costs
         * through local memory without a barrier, the same way
         * bidir_satd_8x8_ii_coop4 exchanges through its volatile tmp; use a
         * volatile view so each read is an actual load of the other lanes'
         * stores. */
        volatile local int16_t *lane_costs = cost_local + mb_in_group * 4;
        satd_local += mb_in_group * 16;

#define TRY_BIDIR( mv0, mv1, penalty )\
        {\
            int2 qpos0 = (int2)((coord.x << 2) + (mv0).x, (coord.y << 2) + (mv0).y);\
            int2 qpos1 = (int2)((coord.x << 2) + (mv1).x, (coord.y << 2) + (mv1).y);\
            lane_costs[mb_i] = bidir_satd_8x8_ii_coop4( fenc_lowres, coord, fref0_planes, qpos0, fref1_planes, qpos1, bipred_weight, satd_local, mb_i );\
            int cost = lane_costs[0] + lane_costs[1] + lane_costs[2] + lane_costs[3];\
            COPY2_IF_LT( bcost, (penalty) * lambda + cost, list_used, 3 );\
        }

        /* temporal prediction: scale the vector read from fref1_lowres_mvs0
         * for this MB by dist_scale_factor / 256. The product is formed in
         * 32 bits because OpenCL applies no integer promotion to vector
         * operands: evaluated as short2 it would wrap whenever
         * |mvr| * dist_scale_factor exceeds 32767. */
        short2 mvr = fref1_lowres_mvs0[mb_xy];
        short2 dmv0 = convert_short2( (convert_int2( mvr ) * dist_scale_factor + 128) >> 8 );
        short2 dmv1 = dmv0 - mvr;
        TRY_BIDIR( dmv0, dmv1, 0 );
        if( as_uint( dmv0 ) || as_uint( dmv1 ) )
        {
            /* B-direct prediction (zero vectors); skipped when the temporal
             * vectors were already both zero. */
            dmv0 = 0;
            dmv1 = 0;
            TRY_BIDIR( dmv0, dmv1, 0 );
        }
        /* L0+L1 prediction, penalized by 5 * lambda */
        dmv0 = fenc_lowres_mvs0[(b - p0 - 1) * mb_count + mb_xy];
        dmv1 = fenc_lowres_mvs1[(p1 - b - 1) * mb_count + mb_xy];
        TRY_BIDIR( dmv0, dmv1, 5 );
#undef TRY_BIDIR
    }
    lowres_costs[mb_xy] = min( bcost, LOWRES_COST_MASK ) + (list_used << LOWRES_COST_SHIFT);
}
  /*
 * parallel sum inter costs
 *
 * Each work-item row (global id 1 = y) handles one macroblock row. Work-items
 * along dimension 0 stride across that row, accumulating partial sums that
 * are then reduced with parallel_sum() through a 256-entry local buffer.
 * Work-item 0 of the row stores the row SATD and atomically adds the row's
 * contributions to frame_stats[COST_EST], [COST_EST_AQ] and [INTRA_MBS].
 *
 * Edge macroblocks are excluded from cost_est / cost_est_aq / intra_mbs
 * unless the frame is at most two macroblocks wide or tall.
 *
 * global launch dimensions: [256, mb_height]
 */
kernel void sum_inter_cost( const global uint16_t *fenc_lowres_costs,
                            const global uint16_t *inv_qscale_factor,
                            global int *fenc_row_satds,
                            global int *frame_stats,
                            int mb_width,
                            int bframe_bias,
                            int b,
                            int p0,
                            int p1 )
{
    local int buffer[256];
    const int lane = get_global_id( 0 );
    const int y = get_global_id( 1 );
    const int mb_height = get_global_size( 1 );
    const int tiny_frame = mb_width <= 2 || mb_height <= 2;
    const int interior_row = y > 0 && y < mb_height - 1;

    int row_satds = 0;
    int cost_est = 0;
    int cost_est_aq = 0;
    int intra_mbs = 0;

    for( int x = lane; x < mb_width; x += get_global_size( 0 ) )
    {
        const int mb_xy = y * mb_width + x;
        const int packed = fenc_lowres_costs[mb_xy];
        const int cost = packed & LOWRES_COST_MASK;
        const int cost_aq = (cost * inv_qscale_factor[mb_xy] + 128) >> 8;
        const int scored = tiny_frame || (interior_row && x > 0 && x < mb_width - 1);

        row_satds += cost_aq;
        if( scored )
        {
            cost_est += cost;
            cost_est_aq += cost_aq;
            /* list 0 in the packed cost marks an intra macroblock */
            intra_mbs += (packed >> LOWRES_COST_SHIFT) == 0;
        }
    }

    row_satds   = parallel_sum( row_satds, lane, buffer );
    cost_est    = parallel_sum( cost_est, lane, buffer );
    cost_est_aq = parallel_sum( cost_est_aq, lane, buffer );
    intra_mbs   = parallel_sum( intra_mbs, lane, buffer );

    if( b != p1 )
    {
        /* Use floating point math to avoid 32bit integer overflow conditions */
        cost_est = (int)((float)cost_est * 100.0f / (120.0f + (float)bframe_bias));
    }

    if( lane == 0 )
    {
        fenc_row_satds[y] = row_satds;
        atomic_add( frame_stats + COST_EST, cost_est );
        atomic_add( frame_stats + COST_EST_AQ, cost_est_aq );
        atomic_add( frame_stats + INTRA_MBS, intra_mbs );
    }
}