File size: 7,342 Bytes
25b4ce2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
#ifndef _qdq_5_cuh
#define _qdq_5_cuh

#include "qdq_util.cuh"
#include "../../config.h"

#if QMODE_5BIT == 1

// Permutation:
//
// v5555533 33311111  u4444422 22200000  (u, v lsb)
// vbbbbb99 99977777  uaaaaa88 88866666
// vhhhhhff fffddddd  ugggggee eeeccccc
// vnnnnnll llljjjjj  ummmmmkk kkkiiiii
// vtttttrr rrrppppp  usssssqq qqqooooo

__forceinline__ __device__ void shuffle_5bit_32
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0 * stride];
    uint32_t qb = q[1 * stride];
    uint32_t qc = q[2 * stride];
    uint32_t qd = q[3 * stride];
    uint32_t qe = q[4 * stride];

    // qa: 66555554 44443333  32222211 11100000
    // qb: ccccbbbb baaaaa99  99988888 77777666
    // qc: jiiiiihh hhhggggg  fffffeee eedddddc
    // qd: pppooooo nnnnnmmm  mmlllllk kkkkjjjj
    // qe: vvvvvuuu uuttttts  ssssrrrr rqqqqqpp

    uint32_t qf = qe >> 22;
    qe <<= 8;
    qe |= qd >> 24;
    qd <<= 6;
    qd |= qc >> 26;
    qc <<= 4;
    qc |= qb >> 28;
    qb <<= 2;
    qb |= qa >> 30;

    // qa:   555554 44443333  32222211 11100000
    // qb:   bbbbba aaaa9999  98888877 77766666
    // qc:   hhhhhg ggggffff  feeeeedd dddccccc
    // qd:   nnnnnm mmmmllll  lkkkkkjj jjjiiiii
    // qe:   ttttts ssssrrrr  rqqqqqpp pppooooo
    // qf:                          vv vvvuuuuu

    uint32_t za = 0;
    uint32_t zb = 0;
    uint32_t zc = 0;
    uint32_t zd = 0;
    uint32_t ze = 0;

    for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); }
    for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); }
    for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); }
    for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); }
    for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); }

    // za:  5555533 33311111   4444422 22200000
    // zb:  bbbbb99 99977777   aaaaa88 88866666
    // zc:  hhhhhff fffddddd   gggggee eeeccccc
    // zd:  nnnnnll llljjjjj   mmmmmkk kkkiiiii
    // ze:  tttttrr rrrppppp   sssssqq qqqooooo
    // qf:                          vv vvvuuuuu

    za |= ((qf & 0x001) >> 0) << 15;
    zb |= ((qf & 0x002) >> 1) << 15;
    zc |= ((qf & 0x004) >> 2) << 15;
    zd |= ((qf & 0x008) >> 3) << 15;
    ze |= ((qf & 0x010) >> 4) << 15;
    za |= ((qf & 0x020) >> 5) << 31;
    zb |= ((qf & 0x040) >> 6) << 31;
    zc |= ((qf & 0x080) >> 7) << 31;
    zd |= ((qf & 0x100) >> 8) << 31;
    ze |= ((qf & 0x200) >> 9) << 31;

    // za: v5555533 33311111  u4444422 22200000  (u, v lsb)
    // zb: vbbbbb99 99977777  uaaaaa88 88866666
    // zc: vhhhhhff fffddddd  ugggggee eeeccccc
    // zd: vnnnnnll llljjjjj  ummmmmkk kkkiiiii
    // ze: vtttttrr rrrppppp  usssssqq qqqooooo

    q[0 * stride] = za;
    q[1 * stride] = zb;
    q[2 * stride] = zc;
    q[3 * stride] = zd;
    q[4 * stride] = ze;
}

__forceinline__ __device__ void dequant_5bit_32
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    const uint32_t q_3,
    const uint32_t q_4,
    half2 (&dq)[16],
    int stride
)
{
    const uint32_t c0 = 0x64006400;
    const half y32_ = __float2half_rn(1.0f / 32.0f);
    const half2 y32 = __halves2half2(y32_, y32_);
    const half z1_  = __float2half_rn(-1024.0f         - 16.0f);
    const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f);
    const half2 z1  = __halves2half2(z1_,  z1_);
    const half2 z32 = __halves2half2(z32_, z32_);

    uint32_t qa = q_0;
    uint32_t qb = q_1;
    uint32_t qc = q_2;
    uint32_t qd = q_3;
    uint32_t qe = q_4;

    half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024
    qa >>= 10;
    half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5])      + 1024
    qa >>= 5;
    qa &= 0x00010001;
    half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7])      + 1024
    half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024
    qb >>= 10;
    half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11])      + 1024
    qb >>= 4;
    qb &= 0x00020002;
    half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13])      + 1024
    half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024
    qc >>= 10;
    half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17])      + 1024
    qc >>= 3;
    qc &= 0x00040004;
    half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19])      + 1024
    half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024
    qd >>= 10;
    half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23])      + 1024
    qd >>= 2;
    qd &= 0x00080008;
    half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25])      + 1024
    half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024
    qe >>= 10;
    half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29])      + 1024
    qe >>= 1;
    qe &= 0x00100010;
    half2_uint32 q15((qa | qb | qc | qd | qe) | c0);

    dq[ 0] = __hadd2( q0.as_half2, z1);
    dq[ 1] = __hfma2( q1.as_half2, y32, z32);
    dq[ 2] = __hadd2( q2.as_half2, z1);
    dq[ 3] = __hadd2( q3.as_half2, z1);
    dq[ 4] = __hfma2( q4.as_half2, y32, z32);
    dq[ 5] = __hadd2( q5.as_half2, z1);
    dq[ 6] = __hadd2( q6.as_half2, z1);
    dq[ 7] = __hfma2( q7.as_half2, y32, z32);
    dq[ 8] = __hadd2( q8.as_half2, z1);
    dq[ 9] = __hadd2( q9.as_half2, z1);
    dq[10] = __hfma2(q10.as_half2, y32, z32);
    dq[11] = __hadd2(q11.as_half2, z1);
    dq[12] = __hadd2(q12.as_half2, z1);
    dq[13] = __hfma2(q13.as_half2, y32, z32);
    dq[14] = __hadd2(q14.as_half2, z1);
    dq[15] = __hadd2(q15.as_half2, z1);
}

#else

__forceinline__ __device__ void shuffle_5bit_32
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_5bit_32
(
    const uint32_t q_0,
    const uint32_t q_1,
    const uint32_t q_2,
    const uint32_t q_3,
    const uint32_t q_4,
    half2 (&dq)[16],
    int stride
)
{
    half dqh[32];
    for (int i = 0; i <  6; i++) dqh[     i] = dq_ns(exb(     q_0, i * 5    , 0x1f), 16);
                                 dqh[ 6    ] = dq_ns(exb(q_1, q_0,        30, 0x1f), 16);
    for (int i = 0; i <  5; i++) dqh[ 7 + i] = dq_ns(exb(     q_1, i * 5 + 3, 0x1f), 16);
                                 dqh[12    ] = dq_ns(exb(q_2, q_1,        28, 0x1f), 16);
    for (int i = 0; i <  6; i++) dqh[13 + i] = dq_ns(exb(     q_2, i * 5 + 1, 0x1f), 16);
                                 dqh[19    ] = dq_ns(exb(q_3, q_2,        31, 0x1f), 16);
    for (int i = 0; i <  5; i++) dqh[20 + i] = dq_ns(exb(     q_3, i * 5 + 4, 0x1f), 16);
                                 dqh[25    ] = dq_ns(exb(q_4, q_3,        29, 0x1f), 16);
    for (int i = 0; i <  6; i++) dqh[26 + i] = dq_ns(exb(     q_4, i * 5 + 2, 0x1f), 16);

    for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

#endif

#endif