mrsu0994 committed
Commit 154f182 · 1 Parent(s): 3bf15e2

upload f5-tts source

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitattributes +4 -0
  2. BUILD +44 -0
  3. Dockerfile +34 -0
  4. WORKSPACE +154 -0
  5. alphabet.txt +97 -0
  6. app.py +145 -0
  7. bazelisk-linux-amd64 +3 -0
  8. build_ext.sh +3 -0
  9. extract_tacotrons_model.py +8 -0
  10. extract_wavegru_model.py +12 -0
  11. inference.py +90 -0
  12. mono_tts_cbhg_small_0700000.ckpt +3 -0
  13. packages.txt +7 -0
  14. pooch.py +10 -0
  15. requirements.txt +13 -0
  16. sparse_matmul/BUILD +22 -0
  17. sparse_matmul/compute/BUILD +88 -0
  18. sparse_matmul/compute/ar_inputs.h +37 -0
  19. sparse_matmul/compute/gru_gates.h +214 -0
  20. sparse_matmul/compute/gru_gates_arm.h +288 -0
  21. sparse_matmul/compute/gru_gates_avx_fixed.h +348 -0
  22. sparse_matmul/compute/gru_gates_generic.h +97 -0
  23. sparse_matmul/compute/gru_gates_test.cc +164 -0
  24. sparse_matmul/compute/kernels_arm.h +0 -0
  25. sparse_matmul/compute/kernels_avx.h +601 -0
  26. sparse_matmul/compute/kernels_generic.h +273 -0
  27. sparse_matmul/compute/matmul.h +199 -0
  28. sparse_matmul/compute/matmul_fixed_avx2.cc +235 -0
  29. sparse_matmul/compute/matmul_fixed_avx2.h +49 -0
  30. sparse_matmul/compute/matmul_generic.cc +122 -0
  31. sparse_matmul/compute/matmul_generic.h +41 -0
  32. sparse_matmul/compute/thread_bounds.cc +106 -0
  33. sparse_matmul/compute/thread_bounds.h +74 -0
  34. sparse_matmul/layers/BUILD +146 -0
  35. sparse_matmul/layers/csr_blocksparse_matrix.h +835 -0
  36. sparse_matmul/layers/csrblocksparse_test.cc +977 -0
  37. sparse_matmul/layers/errno_mapping.cc +195 -0
  38. sparse_matmul/layers/errno_mapping.h +29 -0
  39. sparse_matmul/layers/masked_sparse_matrix.h +206 -0
  40. sparse_matmul/layers/read_array_ifstream.h +66 -0
  41. sparse_matmul/layers/sparse_linear_layer.h +365 -0
  42. sparse_matmul/layers/sparse_linear_layer_test.cc +187 -0
  43. sparse_matmul/layers/status_macros.h +34 -0
  44. sparse_matmul/layers/testdata/768_512_95_4x4_QRhat_weights.raw.gz +3 -0
  45. sparse_matmul/layers/testdata/768_512_95_4x4_What_weights.raw.gz +3 -0
  46. sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_bias.raw.gz +3 -0
  47. sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_mask.raw.gz +3 -0
  48. sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_weights.raw.gz +3 -0
  49. sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_bias.raw.gz +3 -0
  50. sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_mask.raw.gz +3 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ bazelisk-linux-amd64 filter=lfs diff=lfs merge=lfs -text
BUILD ADDED
@@ -0,0 +1,44 @@
+ # [internal] load cc_fuzz_target.bzl
+ # [internal] load cc_proto_library.bzl
+ # [internal] load android_cc_test:def.bzl
+
+ load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
+
+ package(default_visibility = [":__subpackages__"])
+
+ licenses(["notice"])
+
+ # To run all cc_tests in this directory:
+ # bazel test //:all
+
+ # [internal] Command to run dsp_util_android_test.
+
+ # [internal] Command to run lyra_integration_android_test.
+
+ exports_files(
+     srcs = [
+         "wavegru_mod.cc",
+     ],
+ )
+
+ pybind_extension(
+     name = "wavegru_mod",  # This name is not actually created!
+     srcs = ["wavegru_mod.cc"],
+     deps = [
+         "//sparse_matmul",
+     ],
+ )
+
+ py_library(
+     name = "wavegru_mod",
+     data = [":wavegru_mod.so"],
+ )
+
+ py_binary(
+     name = "wavegru",
+     srcs = ["wavegru.py"],
+     deps = [
+         ":wavegru_mod"
+     ],
+ )
+
Dockerfile ADDED
@@ -0,0 +1,34 @@
+ # read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+ # you will also find guides on how best to write your Dockerfile
+
+ FROM us-docker.pkg.dev/colab-images/public/runtime:latest
+
+ RUN apt update; apt install libsndfile1-dev make autoconf automake libtool gcc pkg-config -y python3-dev
+
+ WORKDIR /code
+
+ COPY ./requirements.txt /code/requirements.txt
+
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ # Set up a new user named "user" with user ID 1000
+ RUN useradd -m -u 1000 user
+
+ # Switch to the "user" user
+ USER user
+
+ # Set home to the user's home directory
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ # Set the working directory to the user's home directory
+ WORKDIR $HOME/app
+
+ # Copy the current directory contents into the container at $HOME/app setting the owner to the user
+ COPY --chown=user . $HOME/app
+
+ RUN bash ./build_ext.sh
+ EXPOSE 7860
+ ENV GRADIO_SERVER_NAME="0.0.0.0"
+
+ ENTRYPOINT ["python", "app.py"]
WORKSPACE ADDED
@@ -0,0 +1,154 @@
+ ########################
+ # Platform Independent #
+ ########################
+
+ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository", "new_git_repository")
+ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+
+ # GoogleTest/GoogleMock framework.
+ git_repository(
+     name = "com_google_googletest",
+     remote = "https://github.com/google/googletest.git",
+     tag = "release-1.10.0",
+ )
+
+ # Google benchmark.
+ http_archive(
+     name = "com_github_google_benchmark",
+     urls = ["https://github.com/google/benchmark/archive/bf585a2789e30585b4e3ce6baf11ef2750b54677.zip"],  # 2020-11-26T11:14:03Z
+     strip_prefix = "benchmark-bf585a2789e30585b4e3ce6baf11ef2750b54677",
+     sha256 = "2a778d821997df7d8646c9c59b8edb9a573a6e04c534c01892a40aa524a7b68c",
+ )
+
+ # proto_library, cc_proto_library, and java_proto_library rules implicitly
+ # depend on @com_google_protobuf for protoc and proto runtimes.
+ # This statement defines the @com_google_protobuf repo.
+ git_repository(
+     name = "com_google_protobuf",
+     remote = "https://github.com/protocolbuffers/protobuf.git",
+     tag = "v3.15.4",
+ )
+
+ load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps")
+ protobuf_deps()
+
+ # Google Abseil Libs
+ git_repository(
+     name = "com_google_absl",
+     remote = "https://github.com/abseil/abseil-cpp.git",
+     branch = "lts_2020_09_23",
+ )
+
+ # Filesystem
+ # The new_* prefix is used because it is not a bazel project and there is
+ # no BUILD file in that repo.
+ FILESYSTEM_BUILD = """
+ cc_library(
+     name = "filesystem",
+     hdrs = glob(["include/ghc/*"]),
+     visibility = ["//visibility:public"],
+ )
+ """
+
+ new_git_repository(
+     name = "gulrak_filesystem",
+     remote = "https://github.com/gulrak/filesystem.git",
+     tag = "v1.3.6",
+     build_file_content = FILESYSTEM_BUILD
+ )
+
+ # Audio DSP
+ git_repository(
+     name = "com_google_audio_dsp",
+     remote = "https://github.com/google/multichannel-audio-tools.git",
+     # There are no tags for this repo, we are synced to bleeding edge.
+     branch = "master",
+     repo_mapping = {
+         "@com_github_glog_glog" : "@com_google_glog"
+     }
+ )
+
+
+ http_archive(
+     name = "pybind11_bazel",
+     strip_prefix = "pybind11_bazel-72cbbf1fbc830e487e3012862b7b720001b70672",
+     urls = ["https://github.com/pybind/pybind11_bazel/archive/72cbbf1fbc830e487e3012862b7b720001b70672.zip"],
+ )
+ # We still require the pybind library.
+ http_archive(
+     name = "pybind11",
+     build_file = "@pybind11_bazel//:pybind11.BUILD",
+     strip_prefix = "pybind11-2.10.0",
+     urls = ["https://github.com/pybind/pybind11/archive/v2.10.0.tar.gz"],
+ )
+ load("@pybind11_bazel//:python_configure.bzl", "python_configure")
+ python_configure(name = "local_config_python")
+
+
+
+ # Transitive dependencies of Audio DSP.
+ http_archive(
+     name = "eigen_archive",
+     build_file = "eigen.BUILD",
+     sha256 = "f3d69ac773ecaf3602cb940040390d4e71a501bb145ca9e01ce5464cf6d4eb68",
+     strip_prefix = "eigen-eigen-049af2f56331",
+     urls = [
+         "http://mirror.tensorflow.org/bitbucket.org/eigen/eigen/get/049af2f56331.tar.gz",
+         "https://bitbucket.org/eigen/eigen/get/049af2f56331.tar.gz",
+     ],
+ )
+
+ http_archive(
+     name = "fft2d",
+     build_file = "fft2d.BUILD",
+     sha256 = "ada7e99087c4ed477bfdf11413f2ba8db8a840ba9bbf8ac94f4f3972e2a7cec9",
+     urls = [
+         "http://www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz",
+     ],
+ )
+
+ # Google logging
+ git_repository(
+     name = "com_google_glog",
+     remote = "https://github.com/google/glog.git",
+     tag = "v0.5.0"
+ )
+ # Dependency for glog
+ git_repository(
+     name = "com_github_gflags_gflags",
+     remote = "https://github.com/mchinen/gflags.git",
+     branch = "android_linking_fix"
+ )
+
+ # Bazel/build rules
+
+ http_archive(
+     name = "bazel_skylib",
+     urls = [
+         "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
+         "https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
+     ],
+     sha256 = "97e70364e9249702246c0e9444bccdc4b847bed1eb03c5a3ece4f83dfe6abc44",
+ )
+ load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace")
+ bazel_skylib_workspace()
+
+ http_archive(
+     name = "rules_android",
+     sha256 = "cd06d15dd8bb59926e4d65f9003bfc20f9da4b2519985c27e190cddc8b7a7806",
+     strip_prefix = "rules_android-0.1.1",
+     urls = ["https://github.com/bazelbuild/rules_android/archive/v0.1.1.zip"],
+ )
+
+ # Google Maven Repository
+ GMAVEN_TAG = "20180625-1"
+
+ http_archive(
+     name = "gmaven_rules",
+     strip_prefix = "gmaven_rules-%s" % GMAVEN_TAG,
+     url = "https://github.com/bazelbuild/gmaven_rules/archive/%s.tar.gz" % GMAVEN_TAG,
+ )
+
+ load("@gmaven_rules//:gmaven.bzl", "gmaven_rules")
+
+ gmaven_rules()
alphabet.txt ADDED
@@ -0,0 +1,97 @@
+ _
+
+
+ !
+ ,
+ .
+ :
+ ?
+ a
+ b
+ c
+ d
+ e
+ g
+ h
+ i
+ k
+ l
+ m
+ n
+ o
+ p
+ q
+ r
+ s
+ t
+ u
+ v
+ x
+ y
+ à
+ á
+ â
+ ã
+ è
+ é
+ ê
+ ì
+ í
+ ò
+ ó
+ ô
+ õ
+ ù
+ ú
+ ý
+ ă
+ đ
+ ĩ
+ ũ
+ ơ
+ ư
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ ế
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
app.py ADDED
@@ -0,0 +1,145 @@
+ ## build wavegru-cpp
+ # import os
+ # os.system("./bazelisk-linux-amd64 clean --expunge")
+ # os.system("./bazelisk-linux-amd64 build wavegru_mod -c opt --copt=-march=native")
+
+ # install espeak
+ import os
+ import re
+ import unicodedata
+
+ import regex
+
+ if not os.path.isfile("./wavegru_mod.so"):
+     os.system("bash ./build_ext.sh")
+
+ import gradio as gr
+
+ from inference import load_tacotron_model, load_wavegru_net, mel_to_wav, text_to_mel
+ from wavegru_cpp import extract_weight_mask, load_wavegru_cpp
+
+ alphabet, tacotron_net, tacotron_config = load_tacotron_model(
+     "./alphabet.txt", "./tacotron.toml", "./mono_tts_cbhg_small_0700000.ckpt"
+ )
+
+ wavegru_config, wavegru_net = load_wavegru_net(
+     "./wavegru.yaml", "./wavegru_vocoder_tpu_gta_preemphasis_pruning_0400000.ckpt"
+ )
+
+ wave_cpp_weight_mask = extract_weight_mask(wavegru_net)
+ wavecpp = load_wavegru_cpp(wave_cpp_weight_mask, wavegru_config["upsample_factors"][-1])
+
+
+ space_re = regex.compile(r"\s+")
+ number_re = regex.compile("([0-9]+)")
+ digits = ["không", "một", "hai", "ba", "bốn", "năm", "sáu", "bảy", "tám", "chín"]
+ num_re = regex.compile(r"([0-9.,]*[0-9])")
+ alphabet_ = "aàáảãạăằắẳẵặâầấẩẫậeèéẻẽẹêềếểễệiìíỉĩịoòóỏõọôồốổỗộơờớởỡợuùúủũụưừứửữựyỳýỷỹỵbcdđghklmnpqrstvx"
+ keep_text_and_num_re = regex.compile(rf"[^\s{alphabet_}.,0-9]")
+ keep_text_re = regex.compile(rf"[^\s{alphabet_}]")
+
+
+ def read_number(num: str) -> str:
+     if len(num) == 1:
+         return digits[int(num)]
+     elif len(num) == 2 and num.isdigit():
+         n = int(num)
+         end = digits[n % 10]
+         if n == 10:
+             return "mười"
+         if n % 10 == 5:
+             end = "lăm"
+         if n % 10 == 0:
+             return digits[n // 10] + " mươi"
+         elif n < 20:
+             return "mười " + end
+         else:
+             if n % 10 == 1:
+                 end = "mốt"
+             return digits[n // 10] + " mươi " + end
+     elif len(num) == 3 and num.isdigit():
+         n = int(num)
+         if n % 100 == 0:
+             return digits[n // 100] + " trăm"
+         elif num[1] == "0":
+             return digits[n // 100] + " trăm lẻ " + digits[n % 100]
+         else:
+             return digits[n // 100] + " trăm " + read_number(num[1:])
+     elif len(num) >= 4 and len(num) <= 6 and num.isdigit():
+         n = int(num)
+         n1 = n // 1000
+         return read_number(str(n1)) + " ngàn " + read_number(num[-3:])
+     elif "," in num:
+         n1, n2 = num.split(",")
+         return read_number(n1) + " phẩy " + read_number(n2)
+     elif "." in num:
+         parts = num.split(".")
+         if len(parts) == 2:
+             if parts[1] == "000":
+                 return read_number(parts[0]) + " ngàn"
+             elif parts[1].startswith("00"):
+                 end = digits[int(parts[1][2:])]
+                 return read_number(parts[0]) + " ngàn lẻ " + end
+             else:
+                 return read_number(parts[0]) + " ngàn " + read_number(parts[1])
+         elif len(parts) == 3:
+             return (
+                 read_number(parts[0])
+                 + " triệu "
+                 + read_number(parts[1])
+                 + " ngàn "
+                 + read_number(parts[2])
+             )
+     return num
+
+
+ def normalize_text(text):
+     # lowercase
+     text = text.lower()
+     # unicode normalize
+     text = unicodedata.normalize("NFKC", text)
+     text = text.replace(".", ". ")
+     text = text.replace(",", ", ")
+     text = text.replace(";", "; ")
+     text = text.replace(":", ": ")
+     text = text.replace("!", "! ")
+     text = text.replace("?", "? ")
+     text = text.replace("(", "( ")
+
+     text = num_re.sub(r" \1 ", text)
+     words = text.split()
+     words = [read_number(w) if num_re.fullmatch(w) else w for w in words]
+     text = " ".join(words)
+
+     # remove redundant spaces
+     text = re.sub(r"\s+", " ", text)
+     # remove leading and trailing spaces
+     text = text.strip()
+     return text
+
+
+ def speak(text):
+     text = normalize_text(text)
+     mel = text_to_mel(tacotron_net, text, alphabet, tacotron_config)
+     y = mel_to_wav(wavegru_net, wavecpp, mel, wavegru_config)
+     return 24_000, y
+
+
+ title = "WaveGRU-TTS"
+ description = "WaveGRU text-to-speech demo."
+
+ gr.Interface(
+     fn=speak,
+     inputs="text",
+     examples=[
+         "Trăm năm trong cõi người ta, chữ tài chữ mệnh khéo là ghét nhau.",
+         "Đoạn trường tân thanh, thường được biết đến với cái tên đơn giản là Truyện Kiều, là một truyện thơ của đại thi hào Nguyễn Du.",
+         "Lục Vân Tiên quê ở huyện Đông Thành, khôi ngô tuấn tú, tài kiêm văn võ. Nghe tin triều đình mở khoa thi, Vân Tiên từ giã thầy xuống núi đua tài.",
+         "Lê Quý Đôn, tên thuở nhỏ là Lê Danh Phương, là vị quan thời Lê trung hưng, cũng là nhà thơ và được mệnh danh là nhà bác học lớn của Việt Nam trong thời phong kiến.",
+         "Tất cả mọi người đều sinh ra có quyền bình đẳng. Tạo hóa cho họ những quyền không ai có thể xâm phạm được, trong những quyền ấy, có quyền được sống, quyền tự do và quyền mưu cầu hạnh phúc.",
+     ],
+     outputs="audio",
+     title=title,
+     description=description,
+     theme="default",
+ ).launch()
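For a quick sanity check of the Vietnamese number-reading rules implemented in read_number and normalize_text above, here is a minimal sketch. Assumption: the helpers have been moved into a side-effect-free module (the hypothetical name text_utils below), since importing app.py directly would also load the TTS models.

```python
# Hypothetical usage sketch of the helpers defined in app.py above.
# `text_utils` is a made-up module name; in this commit the functions live in app.py.
from text_utils import normalize_text, read_number

print(read_number("5"))     # "năm"
print(read_number("25"))    # "hai mươi lăm"
print(read_number("105"))   # "một trăm lẻ năm"
print(normalize_text("Tôi có 25 quyển sách!"))  # "tôi có hai mươi lăm quyển sách!"
```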
bazelisk-linux-amd64 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:231ec5ca8115e94c75a1f4fbada1a062b48822ca04f21f26e4cb1cd8973cd458
+ size 5152768
build_ext.sh ADDED
@@ -0,0 +1,3 @@
+ chmod +x ./bazelisk-linux-amd64
+ USE_BAZEL_VERSION=5.0.0 ./bazelisk-linux-amd64 build wavegru_mod -c opt --copt=-march=native
+ cp -f bazel-bin/wavegru_mod.so .
extract_tacotrons_model.py ADDED
@@ -0,0 +1,8 @@
+ import pickle
+
+ import jax
+
+ dic = pickle.load(open("./mono_tts_cbhg_small_0700000.ckpt", "rb"))
+ del dic["optim_state_dict"]
+ dic = jax.device_get(dic)
+ pickle.dump(dic, open("./mono_tts_cbhg_small_0700000.ckpt", "wb"))
extract_wavegru_model.py ADDED
@@ -0,0 +1,12 @@
+ import pickle
+
+ import jax
+
+ dic = pickle.load(
+     open("./wavegru_vocoder_tpu_gta_preemphasis_pruning_0800000.ckpt", "rb")
+ )
+ dic = jax.device_get(dic)
+ del dic["optim_state_dict"]
+ pickle.dump(
+     dic, open("./wavegru_vocoder_tpu_gta_preemphasis_pruning_0800000.ckpt", "wb")
+ )
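Both extract scripts shrink a training checkpoint in place by dropping the optimizer state and pulling the arrays onto the host. A quick, hypothetical way to confirm the result (the remaining key names are not shown in this commit, so they are only printed):

```python
# Hypothetical check that a stripped checkpoint no longer carries optimizer state.
import pickle

with open("./mono_tts_cbhg_small_0700000.ckpt", "rb") as f:
    ckpt = pickle.load(f)

assert "optim_state_dict" not in ckpt   # removed by extract_tacotrons_model.py
print(sorted(ckpt.keys()))              # whatever keys remain are the model weights/config
```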
inference.py ADDED
@@ -0,0 +1,90 @@
+ import os
+
+ import jax
+ import jax.numpy as jnp
+ import librosa
+ import numpy as np
+ import pax
+
+ # from text import english_cleaners
+ from utils import (
+     create_tacotron_model,
+     load_tacotron_ckpt,
+     load_tacotron_config,
+     load_wavegru_ckpt,
+     load_wavegru_config,
+ )
+ from wavegru import WaveGRU
+
+ # os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = "./espeak/usr/lib/libespeak-ng.so.1.1.51"
+ # from phonemizer.backend import EspeakBackend
+ # backend = EspeakBackend("en-us", preserve_punctuation=True, with_stress=True)
+
+
+ def load_tacotron_model(alphabet_file, config_file, model_file):
+     """load tacotron model to memory"""
+     with open(alphabet_file, "r", encoding="utf-8") as f:
+         alphabet = f.read().split("\n")
+
+     config = load_tacotron_config(config_file)
+     net = create_tacotron_model(config)
+     _, net, _ = load_tacotron_ckpt(net, None, model_file)
+     net = net.eval()
+     net = jax.device_put(net)
+     return alphabet, net, config
+
+
+ tacotron_inference_fn = pax.pure(lambda net, text: net.inference(text, max_len=2400))
+
+
+ def text_to_mel(net, text, alphabet, config):
+     """convert text to mel spectrogram"""
+     # text = english_cleaners(text)
+     # text = backend.phonemize([text], strip=True)[0]
+     text = text + config["END_CHARACTER"]
+     text = text + config["PAD"] * (100 - (len(text) % 100))
+     tokens = []
+     for c in text:
+         if c in alphabet:
+             tokens.append(alphabet.index(c))
+     tokens = jnp.array(tokens, dtype=jnp.int32)
+     mel = tacotron_inference_fn(net, tokens[None])
+     return mel
+
+
+ def load_wavegru_net(config_file, model_file):
+     """load wavegru to memory"""
+     config = load_wavegru_config(config_file)
+     net = WaveGRU(
+         mel_dim=config["mel_dim"],
+         rnn_dim=config["rnn_dim"],
+         upsample_factors=config["upsample_factors"],
+         has_linear_output=True,
+     )
+     _, net, _ = load_wavegru_ckpt(net, None, model_file)
+     net = net.eval()
+     net = jax.device_put(net)
+     return config, net
+
+
+ wavegru_inference = pax.pure(lambda net, mel: net.inference(mel, no_gru=True))
+
+
+ def mel_to_wav(net, netcpp, mel, config):
+     """convert mel to wav"""
+     if len(mel.shape) == 2:
+         mel = mel[None]
+     pad = config["num_pad_frames"] // 2 + 2
+     mel = np.pad(mel, [(0, 0), (pad, pad), (0, 0)], mode="edge")
+     ft = wavegru_inference(net, mel)
+     ft = jax.device_get(ft[0])
+     wav = netcpp.inference(ft, 1.0)
+     wav = np.array(wav)
+     wav = librosa.mu_expand(wav - 127, mu=255)
+     wav = librosa.effects.deemphasis(wav, coef=0.86)
+     wav = wav * 2.0
+     wav = wav / max(1.0, np.max(np.abs(wav)))
+     wav = wav * 2**15
+     wav = np.clip(wav, a_min=-(2**15), a_max=(2**15) - 1)
+     wav = wav.astype(np.int16)
+     return wav
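The tail of mel_to_wav above is plain signal post-processing. Below is a minimal sketch of that chain run on random data instead of the WaveGRU output, under the assumption that the C++ sampler returns unsigned 8-bit mu-law indices in [0, 255] (which is why 127 is subtracted before expansion).

```python
# Sketch of the post-processing in mel_to_wav, using dummy data.
# Requires librosa==0.9.0 as pinned in requirements.txt.
import librosa
import numpy as np

mu_law = np.random.randint(0, 256, size=24_000)     # 1 second at 24 kHz (fake samples)
wav = librosa.mu_expand(mu_law - 127, mu=255)       # inverse mu-law companding
wav = librosa.effects.deemphasis(wav, coef=0.86)    # undo the pre-emphasis filter
wav = wav * 2.0
wav = wav / max(1.0, np.max(np.abs(wav)))           # peak-limit only if it clips
wav = np.clip(wav * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
print(wav.dtype, wav.shape)                         # int16 (24000,)
```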
mono_tts_cbhg_small_0700000.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:94a3cf9879f6c71ed21a6569f6f8167a8f4990e46b036b5f8196a16ea14fcb7e
+ size 53480857
packages.txt ADDED
@@ -0,0 +1,7 @@
+ libsndfile1-dev
+ make
+ autoconf
+ automake
+ libtool
+ gcc
+ pkg-config
pooch.py ADDED
@@ -0,0 +1,10 @@
+ def os_cache(x):
+     return x
+
+
+ def create(*args, **kwargs):
+     class T:
+         def load_registry(self, *args, **kwargs):
+             return None
+
+     return T()
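This pooch.py is a local stub that shadows the real pooch package. It appears to exist only so that importing librosa 0.9 (which, as an assumption here, calls pooch.os_cache and pooch.create(...).load_registry(...) at import time for its example-data registry) works without the real dependency. A minimal sketch of what the stub has to absorb:

```python
# Sketch of the import-time calls the stub is shaped to satisfy (assumption:
# this mirrors librosa 0.9's use of pooch for its example-data registry).
import pooch  # resolves to the local ./pooch.py when running from the repo root

cache_dir = pooch.os_cache("librosa")        # the stub just echoes the argument back
fetcher = pooch.create(path=cache_dir, base_url="", registry=None)
assert fetcher.load_registry(None) is None   # nothing is ever downloaded
```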
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ inflect
+ jax
+ jaxlib
+ jinja2
+ librosa==0.9.0
+ numpy
+ pax3 @ git+https://github.com/ntt123/pax.git
+ pyyaml
+ toml
+ unidecode
+ phonemizer
+ gradio
+ setuptools
sparse_matmul/BUILD ADDED
@@ -0,0 +1,22 @@
+ # [internal] load placeholder
+
+ licenses(["notice"])
+
+ cc_library(
+     name = "sparse_matmul",
+     hdrs = [
+         "sparse_matmul.h",
+     ],
+     visibility = ["//visibility:public"],
+     deps = [
+         "//sparse_matmul/compute:gru_gates",
+         "//sparse_matmul/layers:layer",
+         "//sparse_matmul/layers:matrix",
+         "//sparse_matmul/layers:utils",
+         "//sparse_matmul/numerics:fast_transcendentals",
+         "//sparse_matmul/numerics:types",
+         "//sparse_matmul/os:coop_threads",
+         "//sparse_matmul/vector:cache_aligned_vector",
+     ],  # internal :sparse_matmul deps placeholder
+ )
+
sparse_matmul/compute/BUILD ADDED
@@ -0,0 +1,88 @@
+ # Low-level computation code, including generic and architecture-specific
+ # variants.
+
+ licenses(["notice"])
+
+ cc_library(
+     name = "gru_gates",
+     srcs = [
+         "ar_inputs.h",
+         "gru_gates_arm.h",
+         "gru_gates_avx_fixed.h",
+         "gru_gates_generic.h",
+     ],
+     hdrs = ["gru_gates.h"],
+     visibility = [
+         "//visibility:public",
+     ],
+     deps = [
+         ":matmul",
+         "//sparse_matmul/numerics:fast_transcendentals",
+         "//sparse_matmul/numerics:types",
+         "//sparse_matmul/vector:cache_aligned_vector",
+     ],
+ )
+
+ cc_library(
+     name = "kernels",
+     srcs = [
+         "kernels_arm.h",
+         "kernels_avx.h",
+     ],
+     hdrs = [
+         "kernels_generic.h",
+     ],
+     visibility = [
+         "//sparse_matmul:__subpackages__",
+     ],
+     deps = [
+         "//sparse_matmul/numerics:fast_transcendentals",
+         "//sparse_matmul/numerics:types",
+     ],
+ )
+
+ cc_library(
+     name = "matmul",
+     srcs = [
+         "matmul_fixed_avx2.cc",
+         "matmul_fixed_avx2.h",
+         "matmul_generic.cc",
+         "matmul_generic.h",
+     ],
+     hdrs = [
+         "matmul.h",
+     ],
+     visibility = [
+         "//sparse_matmul:__subpackages__",
+     ],
+     deps = [
+         "//sparse_matmul/numerics:types",
+         "@com_google_absl//absl/time",
+     ],
+ )
+
+ cc_library(
+     name = "thread_bounds",
+     srcs = ["thread_bounds.cc"],
+     hdrs = ["thread_bounds.h"],
+     visibility = [
+         "//sparse_matmul:__subpackages__",
+     ],
+     deps = [
+         "@com_google_glog//:glog",
+     ],
+ )
+
+ cc_test(
+     name = "gru_gates_test",
+     size = "small",
+     srcs = [
+         "gru_gates_test.cc",
+     ],
+     deps = [
+         ":gru_gates",
+         "@com_google_absl//absl/memory",
+         "@com_google_absl//absl/types:span",
+         "@com_google_googletest//:gtest_main",
+     ],
+ )
sparse_matmul/compute/ar_inputs.h ADDED
@@ -0,0 +1,37 @@
+ /*
+  * Copyright 2021 Google LLC
+  *
+  * Licensed under the Apache License, Version 2.0 (the "License");
+  * you may not use this file except in compliance with the License.
+  * You may obtain a copy of the License at
+  *
+  *     http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_AR_INPUTS_H_
+ #define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_AR_INPUTS_H_
+
+ namespace csrblocksparse {
+
+ // Possible numbers of Autoregressive inputs.
+ // TODO(b/188702959): Generalize to any non-negative integer value?
+ enum class ARInputsMode {
+   // There are no autoregressive inputs. Inputs to the GRU gates are strictly
+   // from the gate-recurrent matmul and other unrelated inputs.
+   k0ARInputs,
+   // Two autoregressive inputs, such as coarse and fine for WaveRNN.
+   k2ARInputs,
+   // Three autoregressive inputs, such as prev coarse and fine plus current
+   // coarse for WaveRNN.
+   k3ARInputs,
+ };
+
+ }  // namespace csrblocksparse
+
+ #endif  // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_AR_INPUTS_H_
sparse_matmul/compute/gru_gates.h ADDED
@@ -0,0 +1,214 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_
19
+
20
+ #include <cstdint>
21
+ #include <vector>
22
+
23
+ // IWYU pragma: begin_exports
24
+ #include "sparse_matmul/compute/ar_inputs.h"
25
+ #include "sparse_matmul/compute/gru_gates_arm.h"
26
+ #include "sparse_matmul/compute/gru_gates_avx_fixed.h"
27
+ #include "sparse_matmul/compute/gru_gates_generic.h"
28
+ #include "sparse_matmul/compute/matmul.h"
29
+ #include "sparse_matmul/numerics/fixed_types.h"
30
+ #include "sparse_matmul/numerics/type_utils.h"
31
+ #include "sparse_matmul/vector/cache_aligned_vector.h"
32
+ // IWYU pragma: end_exports
33
+
34
+ namespace csrblocksparse {
35
+
36
+ // The master template is really a catch-all for the unimplemented cases to
37
+ // run the generics.
38
+ template <typename GRUStateType, typename InputType, typename SampleType = void>
39
+ class GruGates : public MatmulBase {
40
+ public:
41
+ using SampleWeightType = float;
42
+ static constexpr int kSIMDWidth = kGenericSIMDWidth;
43
+
44
+ // Generic GRU function covers all uses for WaveRNN-like architectures and
45
+ // conditioning.
46
+ // Controlled by template parameters thus:
47
+ // - |kInputsMode| == |k0ARInputs|: There are no autoregressive inputs so
48
+ // |ar_sample0|, |ar_sample1|, |ar_sample2|, |ar_01_weights|,
49
+ // |ar_2_weights| are ignored.
50
+ // - |kInputsMode| == |k2ARInputs|: |ar_sample0|, |ar_sample1| are multiplied
51
+ // by |ar_01_weights| and added to the (conditioning) input.
52
+ // - |kInputsMode| == |k3ARInputs|: |ar_sample2| is multiplied by
53
+ // |ar_2_weights| and added to the other two |ar_inputs| (and added to the
54
+ // conditioning input).
55
+ // - If |kSplitGates| is true: The |*gru_recurrent_other_ptr| is secondary
56
+ // recurrent input that must be added to |*gru_recurrent_ptr|.
57
+ // - |num_replicas| determines the number of duplicates of the output to be
58
+ // written, separated by |replica_stride|.
59
+ // - |start|, |end| are |rows| in [0, |state_size|] to be processed by this
60
+ // thread.
61
+ //
62
+ // Previous state is read from |*gru_state_ptr| and the new state is written
63
+ // to *(|gru_state_ptr| + i * |replica_stride| for i in [0, |num_replicas|)).
64
+ template <ARInputsMode kInputsMode = ARInputsMode::k2ARInputs,
65
+ bool kSplitGates = false>
66
+ void GruWithARInput(int start, int end, int state_size,
67
+ const InputType* gru_recurrent_ptr,
68
+ const InputType* input_ptr, GRUStateType* gru_state_ptr,
69
+ const SampleType* ar_sample0 = nullptr,
70
+ const SampleType* ar_sample1 = nullptr,
71
+ const SampleWeightType* ar_01_weights = nullptr,
72
+ int num_replicas = 1, int replica_stride = 0,
73
+ const SampleType* ar_sample2 = nullptr,
74
+ const SampleWeightType* ar_2_weights = nullptr,
75
+ const InputType* gru_recurrent_other_ptr = nullptr) {
76
+ CHECK_EQ(num_replicas, 1) << "Generic code should always have 1 replica";
77
+ GoThroughGates<GRUStateType, InputType, SampleWeightType, SampleType,
78
+ kInputsMode, kSplitGates>(
79
+ start, end, ar_01_weights, gru_recurrent_ptr, gru_recurrent_other_ptr,
80
+ input_ptr, gru_state_ptr, ar_2_weights, state_size, ar_sample0,
81
+ ar_sample1, ar_sample2);
82
+ }
83
+
84
+ // No AR inputs, no split gates, no batching, no replicated outputs.
85
+ // TODO(b/188702959): Redirect conditioning GRU here, removing code from
86
+ // gru_layer.h.
87
+ // Copy to specializations.
88
+ void PlainGru(int start, int end, int state_size,
89
+ const InputType* gru_recurrent_ptr, const InputType* input_ptr,
90
+ GRUStateType* gru_state_ptr) {
91
+ GruWithARInput<ARInputsMode::k0ARInputs>(
92
+ start, end, state_size, gru_recurrent_ptr, input_ptr, gru_state_ptr);
93
+ }
94
+ };
95
+
96
+ #if defined __ARM_NEON || defined __aarch64__
97
+ // Partial specialization for float.
98
+ template <>
99
+ class GruGates<float, float, float> : public MatmulBase {
100
+ public:
101
+ static constexpr int kSIMDWidth = kNeonSIMDWidth;
102
+
103
+ // Generic GRU function covers all uses for WaveRNN-like architectures and
104
+ // conditioning.
105
+ template <ARInputsMode kInputsMode = ARInputsMode::k2ARInputs,
106
+ bool kSplitGates = false>
107
+ void GruWithARInput(int start, int end, int state_size,
108
+ const float* gru_recurrent_data, const float* input_data,
109
+ float* gru_state_data, const float* ar_sample0 = nullptr,
110
+ const float* ar_sample1 = nullptr,
111
+ const float* ar_01_weights = nullptr,
112
+ int num_replicas = 1, int replica_stride = 0,
113
+ const float* ar_sample2 = nullptr,
114
+ const float* ar_2_weights = nullptr,
115
+ const float* gru_recurrent_other_data = nullptr) {
116
+ DCHECK_EQ(num_replicas, 1) << "ARM code should always have 1 replica";
117
+ GoThroughGatesFloat<kInputsMode, kSplitGates>(
118
+ start, end, ar_01_weights, gru_recurrent_data, gru_recurrent_other_data,
119
+ input_data, gru_state_data, ar_2_weights, state_size, ar_sample0,
120
+ ar_sample1, ar_sample2);
121
+ }
122
+ };
123
+ #endif // defined __ARM_NEON || defined __aarch64__
124
+
125
+ // Partial specialization for fixed types. The sample weights are always float
126
+ // whatever the fixed type of the other weights.
127
+ template <int kGRUStateBits, int kInputBits, int kSampleBits>
128
+ class GruGates<fixed16<kGRUStateBits>, fixed32<kInputBits>,
129
+ fixed16<kSampleBits>> : public MatmulBase {
130
+ public:
131
+ #if defined __ARM_NEON || defined __aarch64__
132
+ static constexpr int kSIMDWidth = kNeonSIMDWidth;
133
+ #elif defined __AVX2__
134
+ static constexpr int kSIMDWidth = kAVX2SIMDWidth * 2;
135
+ #else // Generic case.
136
+ static constexpr int kSIMDWidth = kGenericSIMDWidth;
137
+ #endif // __ARM_NEON || defined __aarch64__ / __AVX2__
138
+
139
+ using GRUStateType = fixed16<kGRUStateBits>;
140
+ using InputType = fixed32<kInputBits>;
141
+ using SampleType = fixed16<kSampleBits>;
142
+ using SampleWeightType = float;
143
+ static constexpr int kInputMantissaBits = InputType::kMantissaBits;
144
+ static constexpr int kSampleMantissaBits = SampleType::kMantissaBits;
145
+ static constexpr int kStateMantissaBits = GRUStateType::kMantissaBits;
146
+ // Generic GRU function covers all uses for WaveRNN-like architectures and
147
+ // conditioning.
148
+ template <ARInputsMode kInputsMode = ARInputsMode::k2ARInputs,
149
+ bool kSplitGates = false>
150
+ void GruWithARInput(int start, int end, int state_size,
151
+ const InputType* gru_recurrent_data,
152
+ const InputType* input_data, GRUStateType* gru_state_data,
153
+ const SampleType* ar_sample0 = nullptr,
154
+ const SampleType* ar_sample1 = nullptr,
155
+ const SampleWeightType* ar_01_weights = nullptr,
156
+ int num_replicas = 1, int replica_stride = 0,
157
+ const SampleType* ar_sample2 = nullptr,
158
+ const SampleWeightType* ar_2_weights = nullptr,
159
+ const InputType* gru_recurrent_other_data = nullptr) {
160
+ #if defined __ARM_NEON || defined __aarch64__ || defined __AVX2__
161
+ const int32_t* gru_recurrent_ptr =
162
+ reinterpret_cast<const int32_t*>(gru_recurrent_data);
163
+ const int32_t* gru_recurrent_other_ptr =
164
+ reinterpret_cast<const int32_t*>(gru_recurrent_other_data);
165
+ const int32_t* input_ptr = reinterpret_cast<const int32_t*>(input_data);
166
+ int16_t* gru_state_ptr = reinterpret_cast<int16_t*>(gru_state_data);
167
+ #if defined __AVX2__
168
+ // The samples are fixed16, but we scale them up here and convert to float
169
+ // so that the product with the QR weights is always on the same scale as
170
+ // InputType, so we don't have to do any more scaling inside.
171
+ const float sample_factor = static_cast<float>(1 << kInputMantissaBits);
172
+ #else
173
+ const float sample_factor = 1.0f;
174
+ #endif
175
+ // AR sample 0 and 1 are packed into a pair because the QR weights are
176
+ // formatted with the weights interleaved for sample 0 and 1.
177
+ std::pair<float, float> ar_sample01;
178
+ float ar_sample2_float = 0.0f;
179
+ if (kInputsMode == ARInputsMode::k2ARInputs ||
180
+ kInputsMode == ARInputsMode::k3ARInputs) {
181
+ ar_sample01 = {static_cast<float>(*ar_sample0) * sample_factor,
182
+ static_cast<float>(*ar_sample1) * sample_factor};
183
+ if (kInputsMode == ARInputsMode::k3ARInputs) {
184
+ ar_sample2_float = static_cast<float>(*ar_sample2) * sample_factor;
185
+ }
186
+ }
187
+ #if defined __AVX2__
188
+ CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!";
189
+ GruGatesAVXFixed<kInputMantissaBits, kStateMantissaBits, kInputsMode,
190
+ kSplitGates>(
191
+ start, end, state_size, gru_recurrent_ptr, input_ptr, &ar_sample01,
192
+ ar_01_weights, num_replicas, replica_stride, &ar_sample2_float,
193
+ ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
194
+ #else // ARM.
195
+ DCHECK_EQ(num_replicas, 1) << "ARM code should always have 1 replica";
196
+ GoThroughGatesFixed<GRUStateType, InputType, kInputsMode, kSplitGates>(
197
+ start, end, ar_01_weights, gru_recurrent_ptr, gru_recurrent_other_ptr,
198
+ input_ptr, gru_state_ptr, ar_2_weights, state_size, &ar_sample01,
199
+ &ar_sample2_float);
200
+ #endif // __AVX2__ / ARM.
201
+ #else // Generic case.
202
+ CHECK_EQ(num_replicas, 1) << "Generic code should always have 1 replica";
203
+ GoThroughGates<GRUStateType, InputType, SampleWeightType, SampleType,
204
+ kInputsMode, kSplitGates>(
205
+ start, end, ar_01_weights, gru_recurrent_data, gru_recurrent_other_data,
206
+ input_data, gru_state_data, ar_2_weights, state_size, ar_sample0,
207
+ ar_sample1, ar_sample2);
208
+ #endif // __ARM_NEON || defined __aarch64__ / __AVX2__
209
+ }
210
+ };
211
+
212
+ } // namespace csrblocksparse
213
+
214
+ #endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_
sparse_matmul/compute/gru_gates_arm.h ADDED
@@ -0,0 +1,288 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_
19
+
20
+ #if defined __ARM_NEON || defined __aarch64__
21
+ #include <arm_neon.h>
22
+ #endif
23
+ #include <cstdint>
24
+
25
+ #include "sparse_matmul/compute/ar_inputs.h"
26
+ #include "sparse_matmul/numerics/fast_transcendentals.h"
27
+
28
+ namespace csrblocksparse {
29
+
30
+ static constexpr int kNeonSIMDWidth = 4;
31
+
32
+ // ------ Scalar calculation --------
33
+ // See "Efficient Neural Audio Synthesis" for a description of the calculation.
34
+ // https://arxiv.org/abs/1802.08435
35
+ //
36
+ // NOTE:
37
+ // |sample| = (|coarse_at_sminus1|, |fine_at_sminus1|,
38
+ // |coarse_at_sminus1|, |fine_at_sminus1|)
39
+ // |w_sample| = (|coarse_at_s|, |coarse_at_s|, |coarse_at_s|, |coarse_at_s|)
40
+ //
41
+ // CHEATSHEET:
42
+ // vld1q_f32 = load 4 32-bit floats
43
+ // vmulq_f32(a, b) : return a * b;
44
+ // vaddq_f32(a, b) : return a + b;
45
+ // vmlaq_f32(c, a, b) : return c + a * b;
46
+ // vpaddq_f32(a, b) : return (a0 + a1, a2 + a3, b0 + b1, b2 + b3)
47
+ // vsubq_f32(a, b) : return a - b;
48
+ // vst1q_f32 = store 4 32-bit floats
49
+ #if defined __ARM_NEON || defined __aarch64__
50
+
51
+ #if !defined __aarch64__
52
+ // Backport of vpaddq_f32 to ARM32.
53
+ inline float32x4_t vpaddq_f32(float32x4_t a, float32x4_t b) {
54
+ float32x2_t a10 = vget_low_f32(a);
55
+ float32x2_t a32 = vget_high_f32(a);
56
+ float32x2_t b10 = vget_low_f32(b);
57
+ float32x2_t b32 = vget_high_f32(b);
58
+ return vcombine_f32(vpadd_f32(a10, a32), vpadd_f32(b10, b32));
59
+ }
60
+ #endif
61
+
62
+ template <ARInputsMode kInputsMode, bool SplitGates>
63
+ void GoThroughGatesFloat(int start, int end, const float* qr_ptr,
64
+ const float* gru_gates_ptr,
65
+ const float* gru_gates_other_ptr,
66
+ const float* conditioning_ptr, float* gru_h_ptr,
67
+ const float* w_hat, int proj_size,
68
+ const float* coarse_at_sminus1,
69
+ const float* fine_at_sminus1,
70
+ const float* coarse_at_s) {
71
+ // Increment all the pointers to save on pointer arithmetic in the loop.
72
+ conditioning_ptr += start;
73
+ gru_h_ptr += start;
74
+ gru_gates_ptr += start;
75
+ if (SplitGates) {
76
+ DCHECK_NE(gru_gates_other_ptr, nullptr);
77
+ gru_gates_other_ptr += start;
78
+ }
79
+ if (kInputsMode != ARInputsMode::k0ARInputs) {
80
+ DCHECK_NE(qr_ptr, nullptr);
81
+ qr_ptr += 2 * start;
82
+ DCHECK_NE(coarse_at_sminus1, nullptr);
83
+ DCHECK_NE(fine_at_sminus1, nullptr);
84
+ if (kInputsMode == ARInputsMode::k3ARInputs) {
85
+ DCHECK_NE(w_hat, nullptr);
86
+ DCHECK_NE(coarse_at_s, nullptr);
87
+ w_hat += start;
88
+ }
89
+ }
90
+ for (int i = start; i < end; i += kNeonSIMDWidth) {
91
+ float32x4_t reset = vld1q_f32(gru_gates_ptr);
92
+ float32x4_t update = vld1q_f32(gru_gates_ptr + proj_size);
93
+ float32x4_t cell = vld1q_f32(gru_gates_ptr + 2 * proj_size);
94
+ float32x4_t qr_cell;
95
+ if (SplitGates) {
96
+ reset = vaddq_f32(reset, vld1q_f32(gru_gates_other_ptr));
97
+ update = vaddq_f32(update, vld1q_f32(gru_gates_other_ptr + proj_size));
98
+ cell = vaddq_f32(cell, vld1q_f32(gru_gates_other_ptr + 2 * proj_size));
99
+ }
100
+ if (kInputsMode != ARInputsMode::k0ARInputs) {
101
+ // Setup the sample vector.
102
+ float32x4_t sample = vdupq_n_f32(*coarse_at_sminus1);
103
+ sample = vsetq_lane_f32(*fine_at_sminus1, sample, 1);
104
+ sample = vsetq_lane_f32(*fine_at_sminus1, sample, 3);
105
+
106
+ // All auto types are float32x4_t, auto used to fit statements on one line
107
+ // for readability. Do two rows of QR at once.
108
+ auto qr_reset_0 = vmulq_f32(vld1q_f32(qr_ptr), sample);
109
+ auto qr_reset_1 = vmulq_f32(vld1q_f32(qr_ptr + 4), sample);
110
+ auto qr_reset = vpaddq_f32(qr_reset_0, qr_reset_1);
111
+
112
+ auto qr_update_0 = vmulq_f32(vld1q_f32(qr_ptr + 2 * proj_size), sample);
113
+ auto qr_update_1 =
114
+ vmulq_f32(vld1q_f32(qr_ptr + 4 + 2 * proj_size), sample);
115
+ auto qr_update = vpaddq_f32(qr_update_0, qr_update_1);
116
+
117
+ auto qr_cell_0 = vmulq_f32(vld1q_f32(qr_ptr + 4 * proj_size), sample);
118
+ auto qr_cell_1 = vmulq_f32(vld1q_f32(qr_ptr + 4 + 4 * proj_size), sample);
119
+ qr_cell = vpaddq_f32(qr_cell_0, qr_cell_1);
120
+
121
+ if (kInputsMode == ARInputsMode::k3ARInputs) {
122
+ float32x4_t w_sample = vdupq_n_f32(*coarse_at_s);
123
+ qr_reset = vmlaq_f32(qr_reset, vld1q_f32(w_hat), w_sample);
124
+ qr_update =
125
+ vmlaq_f32(qr_update, vld1q_f32(w_hat + proj_size), w_sample);
126
+ qr_cell =
127
+ vmlaq_f32(qr_cell, vld1q_f32(w_hat + 2 * proj_size), w_sample);
128
+ }
129
+ reset = vaddq_f32(reset, qr_reset);
130
+ update = vaddq_f32(update, qr_update);
131
+ }
132
+ auto reset_conditioning = vld1q_f32(conditioning_ptr);
133
+ auto update_conditioning = vld1q_f32(conditioning_ptr + proj_size);
134
+ auto cell_conditioning = vld1q_f32(conditioning_ptr + 2 * proj_size);
135
+
136
+ reset = fast_sigmoid(vaddq_f32(reset, reset_conditioning));
137
+ update = fast_sigmoid(vaddq_f32(update, update_conditioning));
138
+ if (kInputsMode == ARInputsMode::k0ARInputs) {
139
+ cell = vmulq_f32(reset, cell);
140
+ } else {
141
+ cell = vmlaq_f32(qr_cell, reset, cell);
142
+ }
143
+ auto hbar = fast_tanh(vaddq_f32(cell, cell_conditioning));
144
+
145
+ auto prev_h = vld1q_f32(gru_h_ptr);
146
+ auto diff = vsubq_f32(prev_h, hbar);
147
+ auto new_h = vmlaq_f32(hbar, diff, update);
148
+
149
+ vst1q_f32(gru_h_ptr, new_h);
150
+ // Increment all the pointers.
151
+ conditioning_ptr += kNeonSIMDWidth;
152
+ gru_h_ptr += kNeonSIMDWidth;
153
+ gru_gates_ptr += kNeonSIMDWidth;
154
+ if (SplitGates) gru_gates_other_ptr += kNeonSIMDWidth;
155
+ if (kInputsMode != ARInputsMode::k0ARInputs) {
156
+ qr_ptr += 2 * kNeonSIMDWidth;
157
+ if (kInputsMode == ARInputsMode::k3ARInputs) w_hat += kNeonSIMDWidth;
158
+ }
159
+ }
160
+ }
161
+
162
+ // This version should only be used if all of the 32-bit fixed point
163
+ // representations have the same number of mantissa bits.
164
+ // |ar_at_sminus1| packs sample 0 and 1 into a pair because the QR weights are
165
+ // formatted with the weights interleaved for sample 0 and 1. The two samples
166
+ // represent coarse and fine for WaveRNN.
167
+ template <typename GRUStateType, typename GRUMatMulOutType,
168
+ ARInputsMode kInputsMode, bool SplitGates>
169
+ void GoThroughGatesFixed(int start, int end, const float* qr_ptr,
170
+ const int32_t* gru_gates_ptr,
171
+ const int32_t* gru_gates_other_ptr,
172
+ const int32_t* conditioning_ptr, int16_t* gru_h_ptr,
173
+ const float* w_hat, int proj_size,
174
+ const std::pair<float, float>* ar_at_sminus1,
175
+ const float* coarse_at_s) {
176
+ // Increment all the pointers to save on pointer arithmetic in the loop.
177
+ conditioning_ptr += start;
178
+ gru_h_ptr += start;
179
+ gru_gates_ptr += start;
180
+ if (SplitGates) {
181
+ DCHECK_NE(gru_gates_other_ptr, nullptr);
182
+ gru_gates_other_ptr += start;
183
+ }
184
+ float32x4_t sample01;
185
+ float32x4_t w_sample;
186
+ if (kInputsMode != ARInputsMode::k0ARInputs) {
187
+ DCHECK_NE(qr_ptr, nullptr);
188
+ qr_ptr += 2 * start;
189
+ DCHECK_NE(ar_at_sminus1, nullptr);
190
+ sample01 = vdupq_n_f32(ar_at_sminus1->first);
191
+ sample01 = vsetq_lane_f32(ar_at_sminus1->second, sample01, 1);
192
+ sample01 = vsetq_lane_f32(ar_at_sminus1->second, sample01, 3);
193
+ if (kInputsMode == ARInputsMode::k3ARInputs) {
194
+ DCHECK_NE(w_hat, nullptr);
195
+ DCHECK_NE(coarse_at_s, nullptr);
196
+ w_hat += start;
197
+ w_sample = vdupq_n_f32(*coarse_at_s);
198
+ }
199
+ }
200
+ for (int i = start; i < end; i += kNeonSIMDWidth) {
201
+ auto reset = vld1q_s32(gru_gates_ptr);
202
+ auto update = vld1q_s32(gru_gates_ptr + proj_size);
203
+ // vcvtq_n_f32_s32 = convert 32-bit fixed point to fp32
204
+ auto cell_int = vld1q_s32(gru_gates_ptr + 2 * proj_size);
205
+ if (SplitGates) {
206
+ reset = vaddq_s32(reset, vld1q_s32(gru_gates_other_ptr));
207
+ update = vaddq_s32(update, vld1q_s32(gru_gates_other_ptr + proj_size));
208
+ cell_int =
209
+ vaddq_s32(cell_int, vld1q_s32(gru_gates_other_ptr + 2 * proj_size));
210
+ }
211
+ float32x4_t cell =
212
+ vcvtq_n_f32_s32(cell_int, GRUMatMulOutType::kMantissaBits);
213
+ float32x4_t qr_cell;
214
+ if (kInputsMode != ARInputsMode::k0ARInputs) {
215
+ // Do two rows of QR at once.
216
+ float32x4_t qr_reset_0 = vmulq_f32(vld1q_f32(qr_ptr), sample01);
217
+ float32x4_t qr_reset_1 = vmulq_f32(vld1q_f32(qr_ptr + 4), sample01);
218
+ float32x4_t qr_reset = vpaddq_f32(qr_reset_0, qr_reset_1);
219
+
220
+ float32x4_t qr_update_0 =
221
+ vmulq_f32(vld1q_f32(qr_ptr + 2 * proj_size), sample01);
222
+ float32x4_t qr_update_1 =
223
+ vmulq_f32(vld1q_f32(qr_ptr + 4 + 2 * proj_size), sample01);
224
+ float32x4_t qr_update = vpaddq_f32(qr_update_0, qr_update_1);
225
+
226
+ float32x4_t qr_cell_0 =
227
+ vmulq_f32(vld1q_f32(qr_ptr + 4 * proj_size), sample01);
228
+ float32x4_t qr_cell_1 =
229
+ vmulq_f32(vld1q_f32(qr_ptr + 4 + 4 * proj_size), sample01);
230
+ qr_cell = vpaddq_f32(qr_cell_0, qr_cell_1);
231
+ if (kInputsMode == ARInputsMode::k3ARInputs) {
232
+ float32x4_t w_sample = vdupq_n_f32(*coarse_at_s);
233
+ qr_reset = vmlaq_f32(qr_reset, vld1q_f32(w_hat), w_sample);
234
+ qr_update =
235
+ vmlaq_f32(qr_update, vld1q_f32(w_hat + proj_size), w_sample);
236
+ qr_cell =
237
+ vmlaq_f32(qr_cell, vld1q_f32(w_hat + 2 * proj_size), w_sample);
238
+ }
239
+ reset = vaddq_s32(
240
+ reset, vcvtq_n_s32_f32(qr_reset, GRUMatMulOutType::kMantissaBits));
241
+ update = vaddq_s32(
242
+ update, vcvtq_n_s32_f32(qr_update, GRUMatMulOutType::kMantissaBits));
243
+ }
244
+
245
+ auto reset_conditioning = vld1q_s32(conditioning_ptr);
246
+ auto update_conditioning = vld1q_s32(conditioning_ptr + proj_size);
247
+ float32x4_t cell_conditioning =
248
+ vcvtq_n_f32_s32(vld1q_s32(conditioning_ptr + 2 * proj_size),
249
+ GRUMatMulOutType::kMantissaBits);
250
+
251
+ float32x4_t reset_f32 = fast_sigmoid<GRUMatMulOutType::kExponentBits>(
252
+ vaddq_s32(reset, reset_conditioning));
253
+ float32x4_t update_f32 = fast_sigmoid<GRUMatMulOutType::kExponentBits>(
254
+ vaddq_s32(update, update_conditioning));
255
+ if (kInputsMode == ARInputsMode::k0ARInputs) {
256
+ cell = vmulq_f32(reset_f32, cell);
257
+ } else {
258
+ cell = vmlaq_f32(qr_cell, reset_f32, cell);
259
+ }
260
+ float32x4_t hbar = fast_tanh(vaddq_f32(cell, cell_conditioning));
261
+
262
+ float32x4_t prev_h = vcvtq_n_f32_s32(vmovl_s16(vld1_s16(gru_h_ptr)),
263
+ GRUStateType::kMantissaBits);
264
+ float32x4_t diff = vsubq_f32(prev_h, hbar);
265
+ float32x4_t new_h = vmlaq_f32(hbar, diff, update_f32);
266
+
267
+ // vcvtq_n_s32_f32 = convert fp32 to signed 32-bit fixed point
268
+ // vqrshrn_n_s32 = saturating, rounding, narrowing right shift - used to
269
+ // convert a 32-bit fixed point value to a 16-bit fixed point value
270
+ vst1_s16(gru_h_ptr,
271
+ vqrshrn_n_s32(
272
+ vcvtq_n_s32_f32(new_h, GRUStateType::kMantissaBits + 16), 16));
273
+ // Increment all the pointers.
274
+ conditioning_ptr += kNeonSIMDWidth;
275
+ gru_h_ptr += kNeonSIMDWidth;
276
+ gru_gates_ptr += kNeonSIMDWidth;
277
+ if (SplitGates) gru_gates_other_ptr += kNeonSIMDWidth;
278
+ if (kInputsMode != ARInputsMode::k0ARInputs) {
279
+ qr_ptr += 2 * kNeonSIMDWidth;
280
+ if (kInputsMode == ARInputsMode::k3ARInputs) w_hat += kNeonSIMDWidth;
281
+ }
282
+ }
283
+ }
284
+ #endif // defined __ARM_NEON || defined __aarch64__
285
+
286
+ } // namespace csrblocksparse
287
+
288
+ #endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_
sparse_matmul/compute/gru_gates_avx_fixed.h ADDED
@@ -0,0 +1,348 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_
19
+
20
+ #include <cstdint>
21
+ #if defined __AVX2__
22
+ #include <immintrin.h>
23
+ #endif
24
+ #include <vector>
25
+
26
+ #include "sparse_matmul/compute/ar_inputs.h"
27
+ #include "sparse_matmul/numerics/fast_transcendentals.h"
28
+
29
+ namespace csrblocksparse {
30
+
31
+ #if defined __AVX2__
32
+
33
+ constexpr int kAVX2SIMDWidth = 8;
34
+
35
+ // Loads 8x fixed32 from |ptr0| and adds to |input|.
36
+ // If |kTwoInputs|, also loads from |ptr1| and adds that as well.
37
+ // Returns the 2 or 3-way sum.
38
+ template <bool kTwoInputs>
39
+ inline __m256i LoadAndAddFixed32(const int32_t* ptr0, const int32_t* ptr1,
40
+ const __m256i& input) {
41
+ __m256i data0 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr0));
42
+ if (kTwoInputs) {
43
+ __m256i data1 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr1));
44
+ data0 = _mm256_add_epi32(data0, data1);
45
+ }
46
+ return _mm256_add_epi32(data0, input);
47
+ }
48
+
49
+ // Loads 8x fixed32 from ptr0.
50
+ // If |kTwoInputs|, also loads from |ptr1| and adds.
51
+ // Multiplies the loaded values by the factor and adds to |input|, which also
52
+ // is converted to float.
53
+ // Returns the sum.
54
+ template <bool kTwoInputs>
55
+ inline __m256 LoadMultiplyAddToFloat(const int32_t* ptr0, const int32_t* ptr1,
56
+ const __m256& float_factor,
57
+ const __m256& input) {
58
+ __m256i data0 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr0));
59
+ if (kTwoInputs) {
60
+ __m256i data1 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr1));
61
+ data0 = _mm256_add_epi32(data0, data1);
62
+ }
63
+ __m256 float_result = _mm256_cvtepi32_ps(data0);
64
+ float_result = _mm256_mul_ps(float_result, float_factor);
65
+ return _mm256_add_ps(float_result, input);
66
+ }
67
+
68
+ // Loads 16x float in 2x 8x registers from |ptr0_1| and multiplies by
69
+ // |input_pairs|, likewise formatted as 8x floats, alternating between the two
70
+ // AR inputs and sums each pair of results, making 8x float results.
71
+ // If |kThreeInputs|, also loads 8x float from |ptr2| and multiplies by
72
+ // |third_input|, which must be formatted as 8x float. The second product is
73
+ // added to the previous result.
74
+ // Returns the sum added to |accumulator|.
75
+ template <bool kThreeInputs>
76
+ inline __m256 MultiplyAddFloat(const __m256& input_pairs,
77
+ const __m256& third_input, const float* ptr0_1,
78
+ const float* ptr2, const __m256& accumulator) {
79
+ __m256 data_pair0 = _mm256_load_ps(ptr0_1);
80
+ __m256 data_pair1 = _mm256_load_ps(ptr0_1 + 8);
81
+ data_pair0 = _mm256_mul_ps(data_pair0, input_pairs);
82
+ data_pair1 = _mm256_mul_ps(data_pair1, input_pairs);
83
+ data_pair0 = _mm256_hadd_ps(data_pair0, data_pair1);
84
+ // Swap the middle 2 64 bit pairs to correct the hadd result.
85
+ data_pair0 = _mm256_permute4x64_pd((__m256d)data_pair0, 0xd8);
86
+ if (kThreeInputs) {
87
+ // Load 256 bits (8 x float) of data, then multiply-accumulate.
88
+ data_pair1 = _mm256_load_ps(ptr2);
89
+ data_pair1 = _mm256_mul_ps(data_pair1, third_input);
90
+ data_pair0 = _mm256_add_ps(data_pair0, data_pair1);
91
+ }
92
+ // Add conditioning.
93
+ return _mm256_add_ps(data_pair0, accumulator);
94
+ }
95
+
96
+ // Processes the tanh and the final combination, returns the new GRU state.
97
+ template <int kInputMantissaBits, int kStateMantissaBits, bool kSplitGates>
98
+ inline __m256i GRUComputeState(const __m256& cell0, const __m256& cell1,
99
+ const __m256& reset0, const __m256& reset1,
100
+ const __m256& update0, const __m256& update1,
101
+ const int32_t* gate_ptr,
102
+ const int32_t* gate_other_ptr,
103
+ const void* gru_h_ptr) {
104
+ // Multiply the cell gru output and the reset.
105
+ __m256 float_gru0 = LoadMultiplyAddToFloat<kSplitGates>(
106
+ gate_ptr, gate_other_ptr, reset0, cell0);
107
+ __m256 float_gru1 = LoadMultiplyAddToFloat<kSplitGates>(
108
+ gate_ptr + kAVX2SIMDWidth, gate_other_ptr + kAVX2SIMDWidth, reset1,
109
+ cell1);
110
+ // Compute tanh on the result.
111
+ __m256 hbar0, hbar1;
112
+ float_tanh_float<kInputMantissaBits, TM_ORDER4_FLOAT>(float_gru0, float_gru1,
113
+ hbar0, hbar1);
114
+ // Load the 16-bit previous gru state and update.
115
+ __m256i gru = _mm256_load_si256(reinterpret_cast<__m256i const*>(gru_h_ptr));
116
+ __m256 state_factor =
117
+ _mm256_set1_ps(1.0f / (static_cast<float>(1 << kStateMantissaBits)));
118
+ float_gru0 =
119
+ _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(gru)));
120
+ float_gru1 = _mm256_cvtepi32_ps(
121
+ _mm256_cvtepi16_epi32(_mm256_extractf128_si256(gru, 1)));
122
+ float_gru0 = _mm256_mul_ps(float_gru0, state_factor);
123
+ float_gru1 = _mm256_mul_ps(float_gru1, state_factor);
124
+ float_gru0 = _mm256_sub_ps(float_gru0, hbar0);
125
+ float_gru1 = _mm256_sub_ps(float_gru1, hbar1);
126
+ float_gru0 = _mm256_mul_ps(float_gru0, update0);
127
+ float_gru1 = _mm256_mul_ps(float_gru1, update1);
128
+ state_factor = _mm256_set1_ps(static_cast<float>(1 << kStateMantissaBits));
129
+ float_gru0 = _mm256_add_ps(float_gru0, hbar0);
130
+ float_gru1 = _mm256_add_ps(float_gru1, hbar1);
131
+ float_gru0 = _mm256_mul_ps(float_gru0, state_factor);
132
+ float_gru1 = _mm256_mul_ps(float_gru1, state_factor);
133
+ return PackFloatsToFixed16(float_gru0, float_gru1);
134
+ }
135
+
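Editor's note: in plain float, the per-element update that GRUComputeState implements is the standard GRU combination below (illustrative sketch only; the real code works on 16 elements in fixed point and rescales via the mantissa-bit template parameters).

#include <cmath>

// hbar = tanh(cell + reset * gate); new_h = hbar + (prev_h - hbar) * update.
inline float GruComputeStateScalar(float cell, float reset, float gate,
                                   float update, float prev_h) {
  float hbar = std::tanh(cell + reset * gate);
  return hbar + (prev_h - hbar) * update;
}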
136
+ // According to |kInputsMode|, processes 0, 2 or 3 autoregressive inputs and
137
+ // combines with |input| and |gates*|.
138
+ // With 2 AR inputs, loads 8x pairs of float from |pair_weights| and multiplies
139
+ // by |paired_ar|, likewise formatted as 8x float, but scaled such that the
140
+ // product with pair_weights is on the same scale as |*input| and |*gates0|,
141
+ // and sums each pair result, making 8x float results.
142
+ // If 3 AR inputs, also loads 8x float from |third_weights| and multiplies by
143
+ // |third_ar|, which must be formatted as 8x scaled floats. The second product
144
+ // is added to the previous result.
145
+ // Inputs, 8x fixed32 are loaded from |input|, and added to the total.
146
+ // Finally 8x fixed32 from |gates0| (and |gates1| if |kTwoGates|) are added as
147
+ // well.
148
+ // Returns the total sum as a float, but on the scale of |*input|.
149
+ template <bool kTwoGates, ARInputsMode kInputsMode>
150
+ inline __m256 GruInput32ToFloat(const __m256& paired_ar,
151
+ const __m256& third_ar,
152
+ const float* pair_weights,
153
+ const float* third_weights,
154
+ const int32_t* gates0, const int32_t* gates1,
155
+ const int32_t* input) {
156
+ __m256i data32 = _mm256_load_si256(reinterpret_cast<__m256i const*>(input));
157
+ data32 = LoadAndAddFixed32<kTwoGates>(gates0, gates1, data32);
158
+ __m256 float_data = _mm256_cvtepi32_ps(data32);
159
+ if (kInputsMode != ARInputsMode::k0ARInputs) {
160
+ float_data = MultiplyAddFloat<kInputsMode == ARInputsMode::k3ARInputs>(
161
+ paired_ar, third_ar, pair_weights, third_weights, float_data);
162
+ }
163
+ return float_data;
164
+ }
165
+
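Editor's note: a one-lane scalar sketch of GruInput32ToFloat (illustrative; ar_term stands for the pairwise AR contribution that MultiplyAddFloat computes):

#include <cstdint>

inline float GruInput32ToFloatScalar(bool two_gates, bool has_ar_inputs,
                                     int32_t input, int32_t gate0,
                                     int32_t gate1, float ar_term) {
  int32_t total = input + gate0;
  if (two_gates) total += gate1;         // kTwoGates adds the second gate term.
  float result = static_cast<float>(total);
  if (has_ar_inputs) result += ar_term;  // AR contribution on the same scale.
  return result;
}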
166
+ // Generic GRU gates function controlled by template parameters thus:
167
+ // - |kInputBits|: the mantissa bits in |*input_ptr|, |*gru_recurrent_ptr|.
168
+ // - |kStateBits|: the mantissa_bits in |*gru_state_ptr|.
169
+ // - |kInputsMode| == |k0ARInputs|: There are no autoregressive inputs so
170
+ // |ar_sample0|, |ar_sample1|, |ar_sample2|, |ar_01_weights|, |ar_2_weights| are
171
+ // ignored.
172
+ // - |kInputsMode| == |k2ARInputs|: |ar_sample0|, |ar_sample1| are multiplied by
173
+ // |ar_01_weights| and added to the (conditioning) input.
174
+ // - |kInputsMode| == |k3ARInputs|: |ar_sample2| is multiplied by |ar_2_weights|
175
+ // and added to the other two AR inputs (and added to the conditioning input).
176
+ // - |kReplicas| determines the number of duplicates of the output to be
177
+ // written, separated by |replica_stride|. If zero, then the number of
178
+ // replicas is variable and taken from the |replicas| argument.
179
+ // - If |kSplitGates| is true: The |*gru_recurrent_other_ptr| is secondary
180
+ // recurrent input that must be added to |*gru_recurrent_ptr|.
181
+ // - |start|, |end| are |rows| in [0, |state_size|] to be processed by this
182
+ // thread.
183
+ //
184
+ // Previous state is read from |*gru_state_ptr| and the new state is written to
185
+ // *(|gru_state_ptr| + i * |replica_stride| for i in [0, |kReplicas|]).
186
+ template <int kInputBits, int kStateBits,
187
+ ARInputsMode kInputsMode = ARInputsMode::k0ARInputs,
188
+ int kReplicas = 1, bool kSplitGates = false>
189
+ inline void GruGatesTemplate(
190
+ int start, int end, int state_size, int replicas, int replica_stride,
191
+ const int32_t* gru_recurrent_ptr, const int32_t* input_ptr,
192
+ const std::pair<float, float>* ar_sample01, const float* ar_01_weights,
193
+ const float* ar_sample2, const float* ar_2_weights,
194
+ const int32_t* gru_recurrent_other_ptr, int16_t* gru_state_ptr) {
195
+ constexpr int kQRIncrement = kAVX2SIMDWidth;
196
+ // Increment all the pointers to save on pointer arithmetic in the loop.
197
+ input_ptr += start;
198
+ gru_state_ptr += start;
199
+ gru_recurrent_ptr += start;
200
+ if (kSplitGates) gru_recurrent_other_ptr += start;
201
+ __m256 ar_2_inputs, ar_3rd_input;
202
+ if (kInputsMode != ARInputsMode::k0ARInputs) {
203
+ ar_01_weights += 2 * start;
204
+ ar_2_inputs = _mm256_castsi256_ps(
205
+ _mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(ar_sample01)));
206
+ if (kInputsMode == ARInputsMode::k3ARInputs) {
207
+ ar_2_weights += start;
208
+ ar_3rd_input = _mm256_set1_ps(*ar_sample2);
209
+ } else {
210
+ ar_3rd_input = {};
211
+ }
212
+ } else {
213
+ ar_2_inputs = {};
214
+ ar_3rd_input = {};
215
+ }
216
+ // The transcendentals handle 2x registers of data at once, so we have to do
217
+ // everything in duplicate.
218
+ for (int i = start; i < end; i += kQRIncrement * 2) {
219
+ // Load 8 pairs of fixed16s for each of reset, update and cell.
220
+ __m256 reset0 = GruInput32ToFloat<kSplitGates, kInputsMode>(
221
+ ar_2_inputs, ar_3rd_input, ar_01_weights, ar_2_weights,
222
+ gru_recurrent_ptr, gru_recurrent_other_ptr, input_ptr);
223
+ __m256 reset1 = GruInput32ToFloat<kSplitGates, kInputsMode>(
224
+ ar_2_inputs, ar_3rd_input, ar_01_weights + 2 * kQRIncrement,
225
+ ar_2_weights + kQRIncrement, gru_recurrent_ptr + kAVX2SIMDWidth,
226
+ gru_recurrent_other_ptr + kAVX2SIMDWidth, input_ptr + kAVX2SIMDWidth);
227
+ float_sigmoid_float<kInputBits>(reset0, reset1);
228
+ __m256 update0 = GruInput32ToFloat<kSplitGates, kInputsMode>(
229
+ ar_2_inputs, ar_3rd_input, ar_01_weights + 2 * state_size,
230
+ ar_2_weights + state_size, gru_recurrent_ptr + state_size,
231
+ gru_recurrent_other_ptr + state_size, input_ptr + state_size);
232
+ __m256 update1 = GruInput32ToFloat<kSplitGates, kInputsMode>(
233
+ ar_2_inputs, ar_3rd_input,
234
+ ar_01_weights + 2 * state_size + 2 * kQRIncrement,
235
+ ar_2_weights + state_size + kQRIncrement,
236
+ gru_recurrent_ptr + state_size + kAVX2SIMDWidth,
237
+ gru_recurrent_other_ptr + state_size + kAVX2SIMDWidth,
238
+ input_ptr + state_size + kAVX2SIMDWidth);
239
+ float_sigmoid_float<kInputBits>(update0, update1);
240
+ __m256 cell0 = _mm256_cvtepi32_ps(_mm256_load_si256(
241
+ reinterpret_cast<__m256i const*>(input_ptr + 2 * state_size)));
242
+ __m256 cell1 =
243
+ _mm256_cvtepi32_ps(_mm256_load_si256(reinterpret_cast<__m256i const*>(
244
+ input_ptr + 2 * state_size + kAVX2SIMDWidth)));
245
+ if (kInputsMode != ARInputsMode::k0ARInputs) {
246
+ cell0 = MultiplyAddFloat<kInputsMode == ARInputsMode::k3ARInputs>(
247
+ ar_2_inputs, ar_3rd_input, ar_01_weights + 4 * state_size,
248
+ ar_2_weights + 2 * state_size, cell0);
249
+ cell1 = MultiplyAddFloat<kInputsMode == ARInputsMode::k3ARInputs>(
250
+ ar_2_inputs, ar_3rd_input,
251
+ ar_01_weights + 4 * state_size + 2 * kQRIncrement,
252
+ ar_2_weights + 2 * state_size + kQRIncrement, cell1);
253
+ }
254
+ __m256i gru_state = GRUComputeState<kInputBits, kStateBits, kSplitGates>(
255
+ cell0, cell1, reset0, reset1, update0, update1,
256
+ gru_recurrent_ptr + 2 * state_size,
257
+ gru_recurrent_other_ptr + 2 * state_size, gru_state_ptr);
258
+ if (kReplicas > 0) {
259
+ // With |kReplicas| a template parameter, the compiler will unroll the
260
+ // loop.
261
+ for (int j = 0; j < kReplicas; ++j) {
262
+ _mm256_store_si256(
263
+ reinterpret_cast<__m256i*>(gru_state_ptr + j * replica_stride),
264
+ gru_state);
265
+ }
266
+ } else {
267
+ // This loop will not unroll as replicas is variable.
268
+ for (int j = 0; j < replicas; ++j) {
269
+ _mm256_store_si256(
270
+ reinterpret_cast<__m256i*>(gru_state_ptr + j * replica_stride),
271
+ gru_state);
272
+ }
273
+ }
274
+ // Increment all the pointers.
275
+ input_ptr += 2 * kAVX2SIMDWidth;
276
+ gru_state_ptr += 2 * kAVX2SIMDWidth;
277
+ gru_recurrent_ptr += 2 * kAVX2SIMDWidth;
278
+ if (kSplitGates) gru_recurrent_other_ptr += 2 * kAVX2SIMDWidth;
279
+ if (kInputsMode != ARInputsMode::k0ARInputs) {
280
+ ar_01_weights += 4 * kQRIncrement;
281
+ if (kInputsMode == ARInputsMode::k3ARInputs)
282
+ ar_2_weights += 2 * kQRIncrement;
283
+ }
284
+ }
285
+ }
286
+
287
+ // Dispatches calls to the GruGatesTemplate function above converting the
288
+ // replicas variable argument to a template parameter to allow the compiler to
289
+ // unroll the write loop.
290
+ // |ar_sample01| packs sample 0 and 1 into a pair because the QR weights are
291
+ // formatted with the weights interleaved for sample 0 and 1. The two samples
292
+ // represent coarse and fine for WaveRNN.
293
+ template <int kInputBits, int kStateBits,
294
+ ARInputsMode kInputsMode = ARInputsMode::k2ARInputs,
295
+ bool kSplitGates = false>
296
+ inline void GruGatesAVXFixed(
297
+ int start, int end, int state_size, const int32_t* gru_recurrent_ptr,
298
+ const int32_t* input_ptr, const std::pair<float, float>* ar_sample01,
299
+ const float* ar_01_weights, int num_replicas, int replica_stride,
300
+ const float* ar_sample2, const float* ar_2_weights,
301
+ const int32_t* gru_recurrent_other_ptr, int16_t* gru_state_ptr) {
302
+ // Convert the number of replicas from a variable to a template parameter
303
+ // with a switch. This enables the compiler to unroll the loop for
304
+ // the write, making it faster for common numbers of threads.
305
+ switch (num_replicas) {
306
+ case 1:
307
+ GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/1,
308
+ kSplitGates>(
309
+ start, end, state_size, num_replicas, replica_stride,
310
+ gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
311
+ ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
312
+ break;
313
+ case 2:
314
+ GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/2,
315
+ kSplitGates>(
316
+ start, end, state_size, num_replicas, replica_stride,
317
+ gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
318
+ ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
319
+ break;
320
+ case 4:
321
+ GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/4,
322
+ kSplitGates>(
323
+ start, end, state_size, num_replicas, replica_stride,
324
+ gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
325
+ ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
326
+ break;
327
+ case 6:
328
+ GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/6,
329
+ kSplitGates>(
330
+ start, end, state_size, num_replicas, replica_stride,
331
+ gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
332
+ ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
333
+ break;
334
+ default:
335
+ // Zero |kReplicas| tells the function to use the |num_replicas| variable.
336
+ GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/0,
337
+ kSplitGates>(
338
+ start, end, state_size, num_replicas, replica_stride,
339
+ gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
340
+ ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
341
+ }
342
+ }
343
+
344
+ #endif // __AVX2__
345
+
346
+ } // namespace csrblocksparse
347
+
348
+ #endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_
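Editor's note: the switch in GruGatesAVXFixed is an instance of a general pattern, converting a small set of runtime values into a template parameter so the compiler can unroll the replica-write loop. A stripped-down sketch of the idea (hypothetical names, not library API):

template <int kReplicas>
void WriteReplicas(int replicas) {
  // kReplicas > 0: the trip count is a compile-time constant and can unroll.
  // kReplicas == 0: fall back to the runtime |replicas| value.
  const int count = kReplicas > 0 ? kReplicas : replicas;
  for (int j = 0; j < count; ++j) {
    // ... store one replica of the result ...
  }
}

void Dispatch(int num_replicas) {
  switch (num_replicas) {
    case 1: WriteReplicas<1>(num_replicas); break;
    case 2: WriteReplicas<2>(num_replicas); break;
    case 4: WriteReplicas<4>(num_replicas); break;
    default: WriteReplicas<0>(num_replicas); break;  // variable count
  }
}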
sparse_matmul/compute/gru_gates_generic.h ADDED
@@ -0,0 +1,97 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_
19
+
20
+ #include "sparse_matmul/compute/ar_inputs.h"
21
+ #include "sparse_matmul/numerics/fast_transcendentals.h"
22
+
23
+ namespace csrblocksparse {
24
+
25
+ constexpr int kGenericSIMDWidth = 4;
26
+
27
+ // TODO(b/188702959): Rename arguments to match gru_gates.h.
28
+ template <typename GRUStateType, typename GRUMatMulOutType, typename QR_W_Type,
29
+ typename SampleType, ARInputsMode kInputsMode,
30
+ bool SplitGates = false>
31
+ void GoThroughGates(int start, int end, const QR_W_Type* qr_ptr,
32
+ const GRUMatMulOutType* gru_gates_ptr,
33
+ const GRUMatMulOutType* gru_gates_other_ptr,
34
+ const GRUMatMulOutType* conditioning_ptr,
35
+ GRUStateType* gru_h_ptr, const QR_W_Type* w_hat,
36
+ int proj_size, const SampleType* coarse_at_sminus1,
37
+ const SampleType* fine_at_sminus1,
38
+ const SampleType* coarse_at_s = nullptr) {
39
+ float qr_cell = 0.0f, reset, update, cell;
40
+ for (int i = start; i < end; ++i) {
41
+ if (kInputsMode == ARInputsMode::k0ARInputs) {
42
+ reset = static_cast<float>(gru_gates_ptr[i]);
43
+ update = static_cast<float>(gru_gates_ptr[proj_size + i]);
44
+ } else {
45
+ float qr_c_reset = static_cast<float>(qr_ptr[2 * i + 0]);
46
+ float qr_f_reset = static_cast<float>(qr_ptr[2 * i + 1]);
47
+ float qr_c_update = static_cast<float>(qr_ptr[2 * proj_size + 2 * i + 0]);
48
+ float qr_f_update = static_cast<float>(qr_ptr[2 * proj_size + 2 * i + 1]);
49
+ float qr_c_cell = static_cast<float>(qr_ptr[4 * proj_size + 2 * i + 0]);
50
+ float qr_f_cell = static_cast<float>(qr_ptr[4 * proj_size + 2 * i + 1]);
51
+ float w_hat_i_reset = 0.0f;
52
+ float w_hat_i_update = 0.0f;
53
+ float w_hat_i_cell = 0.0f;
54
+ if (kInputsMode == ARInputsMode::k3ARInputs) {
55
+ w_hat_i_reset = static_cast<float>(w_hat[i]);
56
+ w_hat_i_update = static_cast<float>(w_hat[proj_size + i]);
57
+ w_hat_i_cell = static_cast<float>(w_hat[2 * proj_size + i]);
58
+ }
59
+ float coarse = static_cast<float>(coarse_at_sminus1[0]);
60
+ float fine = static_cast<float>(fine_at_sminus1[0]);
61
+ reset = qr_c_reset * coarse + qr_f_reset * fine;
62
+ update = qr_c_update * coarse + qr_f_update * fine;
63
+ qr_cell = qr_c_cell * coarse + qr_f_cell * fine;
64
+ if (kInputsMode == ARInputsMode::k3ARInputs) {
65
+ float coarse = static_cast<float>(coarse_at_s[0]);
66
+ reset += w_hat_i_reset * coarse;
67
+ update += w_hat_i_update * coarse;
68
+ qr_cell += w_hat_i_cell * coarse;
69
+ }
70
+ reset += static_cast<float>(gru_gates_ptr[i]);
71
+ update += static_cast<float>(gru_gates_ptr[proj_size + i]);
72
+ }
73
+ cell = static_cast<float>(gru_gates_ptr[2 * proj_size + i]);
74
+ if (SplitGates) {
75
+ reset += static_cast<float>(gru_gates_other_ptr[i]);
76
+ update += static_cast<float>(gru_gates_other_ptr[proj_size + i]);
77
+ cell += static_cast<float>(gru_gates_other_ptr[2 * proj_size + i]);
78
+ }
79
+ float reset_conditioning = static_cast<float>(conditioning_ptr[i]);
80
+ float update_conditioning =
81
+ static_cast<float>(conditioning_ptr[proj_size + i]);
82
+ float cell_conditioning =
83
+ static_cast<float>(conditioning_ptr[2 * proj_size + i]);
84
+ reset = fast_sigmoid(reset + reset_conditioning);
85
+ update = fast_sigmoid(update + update_conditioning);
86
+ float hbar = fast_tanh(qr_cell + reset * cell + cell_conditioning);
87
+ int h_index = i;
88
+ float prev_h = static_cast<float>(gru_h_ptr[h_index]);
89
+ float diff = prev_h - hbar;
90
+ float new_h = hbar + diff * update;
91
+ gru_h_ptr[h_index] = static_cast<GRUStateType>(new_h);
92
+ }
93
+ }
94
+
95
+ } // namespace csrblocksparse
96
+
97
+ #endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_
sparse_matmul/compute/gru_gates_test.cc ADDED
@@ -0,0 +1,164 @@
1
+ // Copyright 2021 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #include "sparse_matmul/compute/gru_gates.h"
16
+
17
+ #include <cstdint>
18
+ #include <cstring>
19
+ #include <numeric>
20
+
21
+ #include "absl/memory/memory.h"
22
+ #include "absl/types/span.h"
23
+ #include "gmock/gmock.h"
24
+ #include "gtest/gtest.h"
25
+
26
+ namespace {
27
+
28
+ using csrblocksparse::ARInputsMode;
29
+
30
+ template <typename GRUStateType, typename InputType, typename SampleType = void,
31
+ csrblocksparse::ARInputsMode kInputsMode, bool kSplitGates>
32
+ csrblocksparse::CacheAlignedVector<GRUStateType> TestGruGates() {
33
+ using SampleWeightType = float;
34
+ constexpr int kStateSize = 16;
35
+ csrblocksparse::CacheAlignedVector<SampleWeightType> qr(6 * kStateSize);
36
+ csrblocksparse::CacheAlignedVector<SampleWeightType> w(3 * kStateSize);
37
+ csrblocksparse::CacheAlignedVector<InputType> gru_gates(3 * kStateSize);
38
+ csrblocksparse::CacheAlignedVector<InputType> gru_other_gates(3 * kStateSize);
39
+ csrblocksparse::CacheAlignedVector<InputType> conditioning(3 * kStateSize);
40
+ csrblocksparse::CacheAlignedVector<GRUStateType> gru_h(kStateSize);
41
+ csrblocksparse::GruGates<GRUStateType, InputType, SampleType> gru_gates_impl;
42
+ const SampleType kCoarseAtSMinus1(0.03f);
43
+ const SampleType kFineAtSMinus1(0.07f);
44
+ const SampleType kCoarseAtS(-0.02f);
45
+
46
+ qr.FillOnes();
47
+ w.FillOnes();
48
+ gru_gates.FillRandom();
49
+ gru_other_gates.FillRandom();
50
+ conditioning.FillRandom();
51
+ gru_h.FillZero();
52
+
53
+ gru_gates_impl.template GruWithARInput<kInputsMode, kSplitGates>(
54
+ /*start=*/0, /*end=*/kStateSize, kStateSize, gru_gates.data(),
55
+ conditioning.data(), gru_h.data(), &kCoarseAtSMinus1, &kFineAtSMinus1,
56
+ qr.data(),
57
+ /*num_replicas=*/1, /*replica_stride=*/0, &kCoarseAtS, w.data(),
58
+ gru_other_gates.data());
59
+ return gru_h;
60
+ }
61
+
62
+ TEST(GruGates, FloatWaveRNNCoarseMatchesGolden) {
63
+ // If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers
64
+ // will also need to change.
65
+ const std::vector<float> kGoldenValues = {
66
+ 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.746f, 0.0f, 0.0f,
67
+ 0.0f, 0.0f, 0.970f, 0.0f, 0.0f, 1.0f, 0.0f, -0.993f};
68
+ csrblocksparse::CacheAlignedVector<float> gru_h =
69
+ TestGruGates<float, float, float, ARInputsMode::k2ARInputs,
70
+ /*kSplitGates=*/true>();
71
+
72
+ ASSERT_EQ(kGoldenValues.size(), gru_h.size());
73
+ for (int i = 0; i < gru_h.size(); ++i) {
74
+ EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i;
75
+ }
76
+ }
77
+
78
+ TEST(GruGates, FloatWaveRNNFineMatchesGolden) {
79
+ // If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers
80
+ // will also need to change.
81
+ const std::vector<float> kGoldenValues = {
82
+ 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.737f, 0.0f, 0.0f,
83
+ 0.0f, 0.0f, 0.969f, 0.0f, 0.0f, 1.0f, 0.0f, -0.994f};
84
+ csrblocksparse::CacheAlignedVector<float> gru_h =
85
+ TestGruGates<float, float, float, ARInputsMode::k3ARInputs,
86
+ /*kSplitGates=*/true>();
87
+
88
+ ASSERT_EQ(kGoldenValues.size(), gru_h.size());
89
+ for (int i = 0; i < gru_h.size(); ++i) {
90
+ EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i;
91
+ }
92
+ }
93
+
94
+ TEST(GruGates, FloatTwoArInputsNonSplitGateMatchesGolden) {
95
+ // If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers
96
+ // will also need to change.
97
+ const std::vector<float> kGoldenValues = {
98
+ 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.714f, 0.0f, -0.002f,
99
+ 0.0f, 0.0f, 0.970f, 0.0f, 0.0f, 1.0f, 0.0f, -0.965f};
100
+ csrblocksparse::CacheAlignedVector<float> gru_h =
101
+ TestGruGates<float, float, float, ARInputsMode::k2ARInputs,
102
+ /*kSplitGates=*/false>();
103
+
104
+ ASSERT_EQ(kGoldenValues.size(), gru_h.size());
105
+ for (int i = 0; i < gru_h.size(); ++i) {
106
+ EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i;
107
+ }
108
+ }
109
+
110
+ TEST(GruGates, FixedWaveRNNCoarseMatchesFloat) {
111
+ using GRUMatMulOutType = csrblocksparse::fixed32<11>;
112
+ using GRUStateType = csrblocksparse::fixed16<2>;
113
+ using SampleType = csrblocksparse::fixed16<0>;
114
+ csrblocksparse::CacheAlignedVector<float> float_gru_h =
115
+ TestGruGates<float, float, float, ARInputsMode::k2ARInputs,
116
+ /*kSplitGates=*/true>();
117
+ csrblocksparse::CacheAlignedVector<GRUStateType> fixed_gru_h =
118
+ TestGruGates<GRUStateType, GRUMatMulOutType, SampleType,
119
+ ARInputsMode::k2ARInputs, /*kSplitGates=*/true>();
120
+
121
+ ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size());
122
+ for (int i = 0; i < fixed_gru_h.size(); ++i) {
123
+ EXPECT_NEAR(float_gru_h[i], static_cast<float>(fixed_gru_h[i]), 1e-3)
124
+ << "i=" << i;
125
+ }
126
+ }
127
+
128
+ TEST(GruGates, FixedWaveRNNFineMatchesFloat) {
129
+ using GRUMatMulOutType = csrblocksparse::fixed32<11>;
130
+ using GRUStateType = csrblocksparse::fixed16<2>;
131
+ using SampleType = csrblocksparse::fixed16<0>;
132
+ csrblocksparse::CacheAlignedVector<float> float_gru_h =
133
+ TestGruGates<float, float, float, ARInputsMode::k3ARInputs,
134
+ /*kSplitGates=*/true>();
135
+ csrblocksparse::CacheAlignedVector<GRUStateType> fixed_gru_h =
136
+ TestGruGates<GRUStateType, GRUMatMulOutType, SampleType,
137
+ ARInputsMode::k3ARInputs, /*kSplitGates=*/true>();
138
+
139
+ ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size());
140
+ for (int i = 0; i < fixed_gru_h.size(); ++i) {
141
+ EXPECT_NEAR(float_gru_h[i], static_cast<float>(fixed_gru_h[i]), 1e-3)
142
+ << "i=" << i;
143
+ }
144
+ }
145
+
146
+ TEST(GruGates, FixedTwoArInputsNonSplitGateMatchesFloat) {
147
+ using GRUMatMulOutType = csrblocksparse::fixed32<11>;
148
+ using GRUStateType = csrblocksparse::fixed16<2>;
149
+ using SampleType = csrblocksparse::fixed16<0>;
150
+ csrblocksparse::CacheAlignedVector<float> float_gru_h =
151
+ TestGruGates<float, float, float, ARInputsMode::k2ARInputs,
152
+ /*kSplitGates=*/false>();
153
+ csrblocksparse::CacheAlignedVector<GRUStateType> fixed_gru_h =
154
+ TestGruGates<GRUStateType, GRUMatMulOutType, SampleType,
155
+ ARInputsMode::k2ARInputs, /*kSplitGates=*/false>();
156
+
157
+ ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size());
158
+ for (int i = 0; i < fixed_gru_h.size(); ++i) {
159
+ EXPECT_NEAR(float_gru_h[i], static_cast<float>(fixed_gru_h[i]), 1e-3)
160
+ << "i=" << i;
161
+ }
162
+ }
163
+
164
+ } // namespace
sparse_matmul/compute/kernels_arm.h ADDED
The diff for this file is too large to render. See raw diff
 
sparse_matmul/compute/kernels_avx.h ADDED
@@ -0,0 +1,601 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_
19
+
20
+ #if defined __AVX__
21
+ #include <immintrin.h>
22
+
23
+ #include <algorithm>
24
+ #include <type_traits>
25
+ // TODO(b/188702959): Remove fast_transcendentals with GRU refactor.
26
+ #include "sparse_matmul/numerics/fast_transcendentals.h"
27
+ #include "sparse_matmul/numerics/fixed_types.h"
28
+ #include "sparse_matmul/numerics/float16_types.h"
29
+ #include "sparse_matmul/numerics/type_utils.h"
30
+
31
+ namespace csrblocksparse {
32
+ namespace detail {
33
+
34
+ template <typename WeightType, typename RhsType, typename OutType>
35
+ struct IsAllowableFloatTypes
36
+ : std::integral_constant<bool, std::is_same<WeightType, float>::value &&
37
+ std::is_same<RhsType, float>::value &&
38
+ std::is_same<OutType, float>::value> {};
39
+
40
+ #if defined __AVX2__
41
+ // 16-bit inputs, 32-bit output exponent matches sum of input exponents
42
+ // OR
43
+ // 16-bit inputs, 16-bit output - will shift to match exponent
44
+ template <typename WeightType, typename RhsType, typename OutType>
45
+ struct IsAllowableFixedTypes
46
+ : std::integral_constant<bool, (IsFixed16Type<WeightType>::value &&
47
+ IsFixed16Type<RhsType>::value) &&
48
+ (IsFixed32Type<OutType>::value ||
49
+ IsFixed16Type<OutType>::value)> {};
50
+
51
+ template <typename WeightType, typename RhsType, typename OutType>
52
+ struct ShouldEnableGenericKernel
53
+ : std::integral_constant<
54
+ bool,
55
+ !IsAllowableFloatTypes<WeightType, RhsType, OutType>::value &&
56
+ !IsAllowableFixedTypes<WeightType, RhsType, OutType>::value> {};
57
+
58
+ template <typename Type>
59
+ struct IsAddableFixedTypes
60
+ : std::integral_constant<bool, IsFixed32Type<Type>::value ||
61
+ IsFixed16Type<Type>::value> {};
62
+ template <typename Type>
63
+ struct ShouldEnableGenericAdd
64
+ : std::integral_constant<bool, !IsAddableFixedTypes<Type>::value> {};
65
+
66
+ #else // No AVX2.
67
+
68
+ template <typename WeightType, typename RhsType, typename OutType>
69
+ struct ShouldEnableGenericKernel
70
+ : std::integral_constant<
71
+ bool, !IsAllowableFloatTypes<WeightType, RhsType, OutType>::value> {};
72
+
73
+ template <typename Type>
74
+ struct ShouldEnableGenericAdd : std::true_type {};
75
+ #endif // __AVX2__
76
+
77
+ template <typename WeightType, typename RhsType, typename OutType>
78
+ struct ShouldEnableGenericSpMV_4x4
79
+ : ShouldEnableGenericKernel<WeightType, RhsType, OutType> {};
80
+ template <typename WeightType, typename RhsType, typename OutType>
81
+ struct ShouldEnableGenericSpMM5_4x4
82
+ : ShouldEnableGenericKernel<WeightType, RhsType, OutType> {};
83
+ template <typename WeightType, typename RhsType, typename OutType>
84
+ struct ShouldEnableGenericSpMV_1x1 : std::true_type {};
85
+ template <typename WeightType, typename RhsType, typename OutType>
86
+ struct ShouldEnableGenericSpMM5_1x1 : std::true_type {};
87
+
88
+ // The computational routines do NO error checking for speed. It is assumed
89
+ // that this has been handled by CSRBlockSparseMatrix.
90
+
91
+ // In-line function to extract results from a pair of registers and store in
92
+ // memory. Note that the non-const references are registers, and are modified
93
+ // by this function!
94
+ inline void Extract4Results(bool relu, __m256& sum1, __m256& sum2,
95
+ float** out_ptr) {
96
+ // Horizontally add the results. We have 2 registers, |sum1| and |sum2| that
97
+ // each contain 2 sets of 4 values that need to be added.
98
+ sum1 = _mm256_hadd_ps(sum1, sum2);
99
+ sum1 = _mm256_hadd_ps(sum1, sum1);
100
+ // Now |sum1| contains [|res0|, |res2|, |res0|, |res2|, |res1|, |res3|,
101
+ // |res1|, |res3|]
102
+ if (relu) {
103
+ sum1 = _mm256_max_ps(sum1, _mm256_setzero_ps());
104
+ }
105
+ // It is really hard in AVX to cross the 128 bit 'lanes' and this is the
106
+ // *only* way to do it.
107
+ // Get the top half of |sum1| in to bottom of |sum2|.
108
+ sum2 = _mm256_permute2f128_ps(sum1, sum1, 1);
109
+ // Interleave the values between the two registers.
110
+ sum1 = _mm256_unpacklo_ps(sum1, sum2);
111
+ // Save the lower 128 bits (4 floats).
112
+ __m128 result = _mm256_extractf128_ps(sum1, 0);
113
+ _mm_store_ps(*out_ptr, result);
114
+ *out_ptr += 4;
115
+ }
116
+
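Editor's note: a scalar reference for Extract4Results (illustrative only). Each of the four outputs is the horizontal sum of one 128-bit half of |sum1| or |sum2|, optionally clamped at zero:

inline void Extract4ResultsScalar(bool relu, const float sum1[8],
                                  const float sum2[8], float** out_ptr) {
  const float* halves[4] = {sum1, sum1 + 4, sum2, sum2 + 4};
  for (int k = 0; k < 4; ++k) {
    float total = halves[k][0] + halves[k][1] + halves[k][2] + halves[k][3];
    if (relu && total < 0.f) total = 0.f;
    (*out_ptr)[k] = total;
  }
  *out_ptr += 4;
}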
117
+ // Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
118
+ // blocked pattern, x is a vector and b is vector. Weights are stored for this
119
+ // routine by making each 4x4 block contiguous. Blocks are ordered in standard
120
+ // row-major format. column indices are converted to deltas and then multiplied
121
+ // by 2 to convert to bytes, so that the value can be used directly to offset
122
+ // the pointer into the rhs vector.
123
+ //
124
+ // NOTE: The bias is expected to have been multiplied by .25f prior to calling
125
+ // this function. This is automatically taken care of in SparseLinearLayer.
126
+ // The bias is reconstructed through horizontal additions, which leads to a small
127
+ // speedup by reducing latencies at the end of the loop.
128
+ template <typename WeightType, typename RhsType, typename OutType>
129
+ typename std::enable_if<std::is_same<WeightType, float>::value &&
130
+ std::is_same<RhsType, float>::value &&
131
+ std::is_same<OutType, float>::value>::type
132
+ SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
133
+ const int32_t* nnz_per_row, const RhsType* rhs_ptr,
134
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
135
+ OutType* out_ptr, int64_t assigned_rows,
136
+ int64_t rows /* only used in SpMM variants */,
137
+ int64_t cols /* only used in SpMM variants */, int relu) {
138
+ for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
139
+ // Broadcast the biases by 4 to undo the division by 4 in the input biases.
140
+ __m256 sum1 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
141
+ _mm_broadcast_ss(bias_ptr));
142
+ bias_ptr += 2;
143
+ __m256 sum2 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
144
+ _mm_broadcast_ss(bias_ptr));
145
+ bias_ptr += 2;
146
+
147
+ int reduced_col_count = *nnz_per_row++;
148
+ for (int c = 0; c < reduced_col_count; ++c) {
149
+ int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
150
+ rhs_ptr += col_delta;
151
+ // Multiply this 4x4 block.
152
+ __m256 rhs =
153
+ _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
154
+ __m256 weights1 = _mm256_load_ps(weights_ptr);
155
+ weights_ptr += 8;
156
+ sum1 = _mm256_add_ps(sum1, _mm256_mul_ps(weights1, rhs));
157
+ __m256 weights2 = _mm256_load_ps(weights_ptr);
158
+ weights_ptr += 8;
159
+ sum2 = _mm256_add_ps(sum2, _mm256_mul_ps(weights2, rhs));
160
+ }
161
+ Extract4Results(relu, sum1, sum2, &out_ptr);
162
+ }
163
+ }
164
+
165
+ // Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
166
+ // blocked pattern, x is a fat vector with 5 columns and b is vector. b is
167
+ // broadcast. Weights are stored for this routine by making each 4x4 block
168
+ // contiguous. Blocks are ordered in standard row-major format. column indices
169
+ // are converted to deltas and then multiplied by 2 to convert to bytes, so
170
+ // that the value can be used directly to offset the pointer into the rhs
171
+ // vector.
172
+ //
173
+ // NOTE: The bias is expected to have been multiplied by .25f prior to calling
174
+ // this function. This is automatically taken care of in SparseLinearLayer.
175
+ // The bias is reconstructed through horizontal additions, which leads to a small
176
+ // speedup by reducing latencies at the end of the loop.
177
+ template <typename WeightType, typename RhsType, typename OutType>
178
+ typename std::enable_if<std::is_same<WeightType, float>::value &&
179
+ std::is_same<RhsType, float>::value &&
180
+ std::is_same<OutType, float>::value>::type
181
+ SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
182
+ const int32_t* nnz_per_row, const RhsType* rhs_ptr,
183
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
184
+ OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
185
+ int relu) {
186
+ const RhsType* rhs_ptrs[5];
187
+ for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;
188
+
189
+ OutType* out_ptrs[5];
190
+ for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;
191
+
192
+ for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
193
+ // We will accumulate the results in 10 registers, |sum1_0| to |sum2_4|.
194
+ // Broadcast the biases by 4 to undo the division by 4 in the input biases.
195
+ __m256 sum1_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
196
+ _mm_broadcast_ss(bias_ptr));
197
+ bias_ptr += 2;
198
+ __m256 sum2_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
199
+ _mm_broadcast_ss(bias_ptr));
200
+ bias_ptr += 2;
201
+ __m256 sum1_1 = sum1_0;
202
+ __m256 sum2_1 = sum2_0;
203
+ __m256 sum1_2 = sum1_0;
204
+ __m256 sum2_2 = sum2_0;
205
+ __m256 sum1_3 = sum1_0;
206
+ __m256 sum2_3 = sum2_0;
207
+ __m256 sum1_4 = sum1_0;
208
+ __m256 sum2_4 = sum2_0;
209
+
210
+ int reduced_col_count = *nnz_per_row++;
211
+ for (int c = 0; c < reduced_col_count; ++c) {
212
+ int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
213
+ for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta;
214
+
215
+ // Multiply this 4x4 block.
216
+ __m256 rhs =
217
+ _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[0]));
218
+ __m256 weights1 = _mm256_load_ps(weights_ptr);
219
+ weights_ptr += 8;
220
+ sum1_0 = _mm256_add_ps(sum1_0, _mm256_mul_ps(weights1, rhs));
221
+ __m256 weights2 = _mm256_load_ps(weights_ptr);
222
+ weights_ptr += 8;
223
+ sum2_0 = _mm256_add_ps(sum2_0, _mm256_mul_ps(weights2, rhs));
224
+ rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[1]));
225
+ sum1_1 = _mm256_add_ps(sum1_1, _mm256_mul_ps(weights1, rhs));
226
+ sum2_1 = _mm256_add_ps(sum2_1, _mm256_mul_ps(weights2, rhs));
227
+ rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[2]));
228
+ sum1_2 = _mm256_add_ps(sum1_2, _mm256_mul_ps(weights1, rhs));
229
+ sum2_2 = _mm256_add_ps(sum2_2, _mm256_mul_ps(weights2, rhs));
230
+ rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[3]));
231
+ sum1_3 = _mm256_add_ps(sum1_3, _mm256_mul_ps(weights1, rhs));
232
+ sum2_3 = _mm256_add_ps(sum2_3, _mm256_mul_ps(weights2, rhs));
233
+ rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[4]));
234
+ sum1_4 = _mm256_add_ps(sum1_4, _mm256_mul_ps(weights1, rhs));
235
+ sum2_4 = _mm256_add_ps(sum2_4, _mm256_mul_ps(weights2, rhs));
236
+ }
237
+
238
+ Extract4Results(relu, sum1_0, sum2_0, &out_ptrs[0]);
239
+ Extract4Results(relu, sum1_1, sum2_1, &out_ptrs[1]);
240
+ Extract4Results(relu, sum1_2, sum2_2, &out_ptrs[2]);
241
+ Extract4Results(relu, sum1_3, sum2_3, &out_ptrs[3]);
242
+ Extract4Results(relu, sum1_4, sum2_4, &out_ptrs[4]);
243
+ }
244
+ }
245
+
246
+ #ifdef __AVX2__
247
+
248
+ // In-line function to finish the computation of the result as 4x int32 in
249
+ // |sum|.
250
+ inline void Compute4Results(bool relu, int kShiftAmount, __m256i& sum) {
251
+ // Horizontally add the results. We have 1 register that contains results
252
+ // [0 0 1 1 2 2 3 3], but hadd (and almost no other AVX instruction) will not
253
+ // cross lanes, so we end up with [0 1 0 1 2 3 2 3]
254
+ sum = _mm256_hadd_epi32(sum, sum);
255
+ // Permutes the middle two pairs to get the answers together.
256
+ sum = _mm256_permute4x64_epi64(sum, 0xd8);
257
+ if (kShiftAmount > 0) {
258
+ // Shift right with rounding to get the right number of mantissa bits.
259
+ __m256i rounding = _mm256_set1_epi32(1 << (kShiftAmount - 1));
260
+ sum = _mm256_add_epi32(sum, rounding);
261
+ sum = _mm256_srai_epi32(sum, kShiftAmount);
262
+ }
263
+ // Now |sum| contains [|res0|, |res1|, |res2|, |res3|, |res0|, |res1|,
264
+ // |res2|, |res3|]
265
+ if (relu) {
266
+ sum = _mm256_max_epi32(sum, _mm256_setzero_si256());
267
+ }
268
+ }
269
+
270
+ // In-line function to extract the 4x int32 results from |sum| to memory.
271
+ // Non-const reference for |sum| as it is a register.
272
+ inline void Extract4xint32(bool relu, int kShiftAmount, __m256i& sum,
273
+ int32_t** out_ptr) {
274
+ Compute4Results(relu, kShiftAmount, sum);
275
+ // Save the lower 128 bits (4x int32).
276
+ __m128i result = _mm256_extractf128_si256(sum, 0);
277
+ _mm_store_si128(reinterpret_cast<__m128i*>(*out_ptr), result);
278
+ *out_ptr += 4;
279
+ }
280
+
281
+ // In-line function to extract the 4x int32 results from sum to 4x int16 in
282
+ // memory.
283
+ // Non-const reference for |sum| as it is a register.
284
+ inline void Extract4xint16(bool relu, int kShiftAmount, __m256i& sum,
285
+ int16_t** out_ptr) {
286
+ Compute4Results(relu, kShiftAmount, sum);
287
+ // Clip to 16 bit range (with saturation) and pack in the bottom 64 bits.
288
+ // Converts the lower 4x int32 in bottom 128 bits to 4x int16 in bottom 64
289
+ // bits, replicated in the next 64 bits.
290
+ sum = _mm256_packs_epi32(sum, sum);
291
+ // Save 4x int 16 from the bottom 64 bits.
292
+ *reinterpret_cast<int64_t*>(*out_ptr) = _mm256_extract_epi64(sum, 0);
293
+ *out_ptr += 4;
294
+ }
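Editor's note: the fixed-point post-processing in Compute4Results / Extract4xint16 boils down to a round-then-shift rescale followed by saturation when narrowing to 16 bits. A scalar sketch (illustrative names):

#include <algorithm>
#include <cstdint>

inline int32_t RoundAndShiftRight(int32_t value, int shift_amount) {
  if (shift_amount > 0) {
    value += 1 << (shift_amount - 1);  // rounding offset
    value >>= shift_amount;            // arithmetic shift, like _mm256_srai_epi32
  }
  return value;
}

inline int16_t SaturateToInt16(int32_t value) {
  // _mm256_packs_epi32 applies the same saturation.
  return static_cast<int16_t>(
      std::min<int32_t>(32767, std::max<int32_t>(-32768, value)));
}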
295
+
296
+ // Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
297
+ // blocked pattern, x is a vector and b is vector. Weights are stored for this
298
+ // routine by making each 4x4 block contiguous. Blocks are ordered in standard
299
+ // row-major format. column indices are converted to deltas and then multiplied
300
+ // by 2 to convert to bytes, so that the value can be used directly to offset
301
+ // the pointer into the rhs vector.
302
+ //
303
+ // NOTE: The bias is expected to have been multiplied by .25f prior to calling
304
+ // this function. This is automatically taken care of in SparseLinearLayer.
305
+ // The bias is reconstructed through horizontal additions, which leads to a small
306
+ // speedup by reducing latencies at the end of the loop.
307
+ template <typename WeightType, typename RhsType, typename OutType>
308
+ typename std::enable_if<
309
+ IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value &&
310
+ (IsFixed32Type<OutType>::value || IsFixed16Type<OutType>::value)>::type
311
+ SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
312
+ const int32_t* nnz_per_row, const RhsType* rhs_ptr,
313
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
314
+ OutType* out_ptr, int64_t assigned_rows,
315
+ int64_t rows /* only used in SpMM variants */,
316
+ int64_t cols /* only used in SpMM variants */, int relu) {
317
+ constexpr int kShiftAmount =
318
+ TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
319
+ OutType::kMantissaBits;
320
+ static_assert(kShiftAmount >= 0,
321
+ "Result must have fewer mantissa bits than product");
322
+ for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
323
+ // Load the biases duplicated into a 256 bit register [0 1 2 3 0 1 2 3].
324
+ __m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr));
325
+ __m256i biases = _mm256_set_m128i(bias, bias);
326
+ bias_ptr += 4;
327
+ // Swap the top two pairs: [0 1 2 3 2 3 0 1]
328
+ // TODO(b/188702959): consider |_mm256_permutevar8x32|, and set the index
329
+ // register outside the row loop.
330
+ biases = _mm256_permute4x64_epi64(biases, 0xb4);
331
+ // Duplicate the low pairs in each lane: [0 0 1 1 2 2 3 3].
332
+ biases = _mm256_unpacklo_epi32(biases, biases);
333
+ // Double the results to make up for the division by 4.
334
+ // TODO(b/188702959): consider moving this to where the biases are computed.
335
+ __m256i sum = _mm256_add_epi32(biases, biases);
336
+
337
+ // TODO(b/188702959): People don't like the old-fashioned, close-to-the-
338
+ // metal notation of *|nnz_per_row|++, so measure the effect of putting the
339
+ // increment in the for loop.
340
+ int reduced_col_count = *nnz_per_row;
341
+ ++nnz_per_row;
342
+ for (int c = 0; c < reduced_col_count; ++c) {
343
+ int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
344
+ rhs_ptr += col_delta;
345
+ // Multiply this 4x4 block.
346
+ // Get the 4x int16 into the bottom of rhs_64.
347
+ __m128i rhs_64 =
348
+ _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptr));
349
+ // Load all 16 weights.
350
+ __m256i weights =
351
+ _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
352
+ // Broadcast the rhs, pretending that each is a 64-bit unit:
353
+ // [0123 0123 0123 0123].
354
+ __m256i rhs = _mm256_broadcastq_epi64(rhs_64);
355
+ weights_ptr += 16;
356
+ // |_mm256_madd_epi16| does 16x16x16=16x32 bit multiply and horizontally
357
+ // adds adjacent pairs to make 8x32 bit results. Add these to the sum.
358
+ sum = _mm256_add_epi32(sum, _mm256_madd_epi16(weights, rhs));
359
+ }
360
+ static_assert(
361
+ IsFixed16Type<OutType>::value || IsFixed32Type<OutType>::value,
362
+ "AVX2 kernel only supports fixed16 and fixed32 types");
363
+ // The only significant difference between fixed16 and fixed32 is the size
364
+ // of the storage unit. The registers have to be repacked accordingly.
365
+ if (IsFixed32Type<OutType>::value) {
366
+ Extract4xint32(relu, kShiftAmount, sum,
367
+ reinterpret_cast<int32_t**>(&out_ptr));
368
+ } else {
369
+ Extract4xint16(relu, kShiftAmount, sum,
370
+ reinterpret_cast<int16_t**>(&out_ptr));
371
+ }
372
+ }
373
+ }
374
+
375
+ // Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
376
+ // blocked pattern, x is a fat vector with 5 columns and b is vector. b is
377
+ // broadcast. Weights are stored for this routine by making each 4x4 block
378
+ // contiguous. Blocks are ordered in standard row-major format. column indices
379
+ // are converted to deltas and then multiplied by 2 to convert to bytes, so
380
+ // that the value can be used directly to offset the pointer into the rhs
381
+ // vector.
382
+ //
383
+ // NOTE: The bias is expected to have been multiplied by .25f prior to calling
384
+ // this function. This is automatically taken care of in SparseLinearLayer.
385
+ // The bias is reconstructed through horizontal additions, which leads to a small
386
+ // speedup by reducing latencies at the end of the loop.
387
+ template <typename WeightType, typename RhsType, typename OutType>
388
+ typename std::enable_if<
389
+ IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value &&
390
+ (IsFixed32Type<OutType>::value || IsFixed16Type<OutType>::value)>::type
391
+ SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
392
+ const int32_t* nnz_per_row, const RhsType* rhs_ptr,
393
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
394
+ OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
395
+ int relu) {
396
+ constexpr int kShiftAmount =
397
+ TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
398
+ OutType::kMantissaBits;
399
+ static_assert(kShiftAmount >= 0,
400
+ "Result must have fewer mantissa bits than product");
401
+ const RhsType* rhs_ptrs[5];
402
+ for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;
403
+
404
+ OutType* out_ptrs[5];
405
+ for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;
406
+
407
+ for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
408
+ // We will accumulate the results in 5 registers, sum_0 to sum_4.
409
+ // Load the biases duplicated into a 256 bit register [0 1 2 3 0 1 2 3].
410
+ __m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr));
411
+ __m256i biases = _mm256_set_m128i(bias, bias);
412
+ bias_ptr += 4;
413
+ // Swap the top two pairs: [0 1 2 3 2 3 0 1]
414
+ biases = _mm256_permute4x64_epi64(biases, 0xb4);
415
+ // Duplicate the low pairs in each lane: [0 0 1 1 2 2 3 3].
416
+ biases = _mm256_unpacklo_epi32(biases, biases);
417
+ // Double the results to make up for the division by 4.
418
+ __m256i sum_0 = _mm256_add_epi32(biases, biases);
419
+ __m256i sum_1 = sum_0;
420
+ __m256i sum_2 = sum_0;
421
+ __m256i sum_3 = sum_0;
422
+ __m256i sum_4 = sum_0;
423
+
424
+ int reduced_col_count = *nnz_per_row;
425
+ ++nnz_per_row;
426
+ for (int c = 0; c < reduced_col_count; ++c) {
427
+ int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
428
+ for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta;
429
+ // Multiply this 4x4 block.
430
+ // Get the 4x int16 into the bottom of |rhs_64|.
431
+ __m128i rhs_64 =
432
+ _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[0]));
433
+ // Load all 16 weights.
434
+ __m256i weights =
435
+ _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
436
+ // Broadcast the rhs, pretending that each is a 64-bit unit:
437
+ // [0123 0123 0123 0123].
438
+ __m256i rhs = _mm256_broadcastq_epi64(rhs_64);
439
+ weights_ptr += 16;
440
+ // |_mm256_madd_epi16| does 16x16x16=16x32 bit multiply and horizontally
441
+ // adds adjacent pairs to make 8x32 bit results. Add these to the sum.
442
+ sum_0 = _mm256_add_epi32(sum_0, _mm256_madd_epi16(weights, rhs));
443
+ rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[1]));
444
+ rhs = _mm256_broadcastq_epi64(rhs_64);
445
+ sum_1 = _mm256_add_epi32(sum_1, _mm256_madd_epi16(weights, rhs));
446
+ rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[2]));
447
+ rhs = _mm256_broadcastq_epi64(rhs_64);
448
+ sum_2 = _mm256_add_epi32(sum_2, _mm256_madd_epi16(weights, rhs));
449
+ rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[3]));
450
+ rhs = _mm256_broadcastq_epi64(rhs_64);
451
+ sum_3 = _mm256_add_epi32(sum_3, _mm256_madd_epi16(weights, rhs));
452
+ rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[4]));
453
+ rhs = _mm256_broadcastq_epi64(rhs_64);
454
+ sum_4 = _mm256_add_epi32(sum_4, _mm256_madd_epi16(weights, rhs));
455
+ }
456
+ static_assert(
457
+ IsFixed16Type<OutType>::value || IsFixed32Type<OutType>::value,
458
+ "AVX2 kernel only supports fixed16 and fixed32 types");
459
+ // The only significant difference between fixed16 and fixed32 is the size
460
+ // of the storage unit. The registers have to be repacked accordingly.
461
+ if (IsFixed32Type<OutType>::value) {
462
+ Extract4xint32(relu, kShiftAmount, sum_0,
463
+ reinterpret_cast<int32_t**>(&out_ptrs[0]));
464
+ Extract4xint32(relu, kShiftAmount, sum_1,
465
+ reinterpret_cast<int32_t**>(&out_ptrs[1]));
466
+ Extract4xint32(relu, kShiftAmount, sum_2,
467
+ reinterpret_cast<int32_t**>(&out_ptrs[2]));
468
+ Extract4xint32(relu, kShiftAmount, sum_3,
469
+ reinterpret_cast<int32_t**>(&out_ptrs[3]));
470
+ Extract4xint32(relu, kShiftAmount, sum_4,
471
+ reinterpret_cast<int32_t**>(&out_ptrs[4]));
472
+ } else {
473
+ Extract4xint16(relu, kShiftAmount, sum_0,
474
+ reinterpret_cast<int16_t**>(&out_ptrs[0]));
475
+ Extract4xint16(relu, kShiftAmount, sum_1,
476
+ reinterpret_cast<int16_t**>(&out_ptrs[1]));
477
+ Extract4xint16(relu, kShiftAmount, sum_2,
478
+ reinterpret_cast<int16_t**>(&out_ptrs[2]));
479
+ Extract4xint16(relu, kShiftAmount, sum_3,
480
+ reinterpret_cast<int16_t**>(&out_ptrs[3]));
481
+ Extract4xint16(relu, kShiftAmount, sum_4,
482
+ reinterpret_cast<int16_t**>(&out_ptrs[4]));
483
+ }
484
+ }
485
+ }
486
+
487
+ // Processes one GRU gate input with sigmoid.
488
+ template <int InputMantissaBits, int StateMantissaBits, bool SplitGates>
489
+ inline __m256i GRUGateSigmoid(const void* gate_ptr, const void* gate_other_ptr,
490
+ const __m256i& input,
491
+ const int32_t* sigmoid_table) {
492
+ __m256i gate = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(gate_ptr));
493
+ if (SplitGates) {
494
+ __m256i other =
495
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(gate_other_ptr));
496
+ gate = _mm256_add_epi32(gate, other);
497
+ }
498
+ gate = _mm256_add_epi32(gate, input);
499
+ // Compute sigmoids on reset and update.
500
+ return csrblocksparse::fixed32_sigmoid_fixed16<InputMantissaBits,
501
+ StateMantissaBits>(
502
+ sigmoid_table, gate);
503
+ }
504
+
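Editor's note: in plain float, GRUGateSigmoid amounts to the following per element (illustrative sketch; the real version stays in fixed point and uses a sigmoid lookup table):

#include <cmath>

inline float GruGateSigmoidScalar(bool split_gates, float gate,
                                  float other_gate, float input) {
  float sum = gate + input;
  if (split_gates) sum += other_gate;
  return 1.0f / (1.0f + std::exp(-sum));
}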
505
+ // Processes the tanh and the final combination, returning the new GRU state.
506
+ template <int InputMantissaBits, int StateMantissaBits, bool SplitGates = false>
507
+ inline __m256i GRUGateState(const __m256i& cell, const __m256i& reset,
508
+ const __m256i& update,
509
+ const __m256i& rounding_offset,
510
+ const void* gate_ptr, const void* gate_other_ptr,
511
+ const void* gru_h_ptr, const int32_t* tanh_table) {
512
+ // Multiply the cell GRU output and the reset. There is a slight danger of
513
+ // loss of precision here, so use 32x32=64 bit and shift back after.
514
+ __m256i gru = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_ptr));
515
+ if (SplitGates) {
516
+ __m256i other_gru =
517
+ _mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_other_ptr));
518
+ gru = _mm256_add_epi32(gru, other_gru);
519
+ }
520
+ // This only computes the products of the low-order 32 bits of each pair.
521
+ __m256i gru_lo = _mm256_mul_epi32(gru, reset);
522
+ // Swap odd and even 32-bit units and do it again to get the high products.
523
+ gru = _mm256_shuffle_epi32(gru, 0xb1);
524
+ __m256i gru_hi = _mm256_mul_epi32(gru, _mm256_shuffle_epi32(reset, 0xb1));
525
+ // Now shift right to compensate for the multiply and re-interleave the
526
+ // 32-bit results.
527
+ // NOTE: There is no shift right arithmetic for 64 bit values until AVX512!
528
+ // Fortunately it doesn't matter, as the results are being truncated to 32
529
+ // bits and we aren't shifting right by more than 32 bits here.
530
+ gru_lo = _mm256_srli_epi64(gru_lo, StateMantissaBits);
531
+ // The upper results are shifted LEFT, so we can use blend to recombine in
532
+ // a single instruction.
533
+ gru_hi = _mm256_slli_epi64(gru_hi, 32 - StateMantissaBits);
534
+ // Recombine the 32 bit results from lo and hi, alternating.
535
+ gru = _mm256_blend_epi32(gru_lo, gru_hi, 0xaa);
536
+ gru = _mm256_add_epi32(cell, gru);
537
+ // Compute tanh on the result. Although this instantly discards a bunch of
538
+ // bits, there were only 7 surplus bits for the multiply, which isn't enough
539
+ // to do it as 16x16=32.
540
+ __m256i hbar =
541
+ csrblocksparse::fixed32_tanh_fixed16<InputMantissaBits,
542
+ StateMantissaBits>(tanh_table, gru);
543
+ // Load the 16-bit previous GRU state and sign-extend to 32 bits.
544
+ gru = _mm256_cvtepi16_epi32(
545
+ _mm_load_si128(reinterpret_cast<__m128i const*>(gru_h_ptr)));
546
+ gru = _mm256_sub_epi32(gru, hbar);
547
+ // Since |gru| is 16 bit sign-extended to 32, and |update| is the output of
548
+ // sigmoid, it is always contained within 16 bits and never negative, we can
549
+ // use |madd_epi16| to do 16x16=32 multiply with horizontal adding as the
550
+ // addend will always be zero, and this is twice as fast as full blown
551
+ // 32x32=32. The only possible problem is if the subtract above caused
552
+ // overflow.
553
+ gru = _mm256_madd_epi16(gru, update);
554
+ // Renormalize to fixed16. This time rounding is critical, as this is the
555
+ // output GRU state.
556
+ gru = _mm256_add_epi32(gru, rounding_offset);
557
+ gru = _mm256_srai_epi32(gru, StateMantissaBits);
558
+ return _mm256_add_epi32(gru, hbar);
559
+ }
560
+
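Editor's note: per element, the reset-times-gate product at the top of GRUGateState is a plain fixed-point multiply. A scalar sketch of roughly what the lo/hi 64-bit multiplies, shifts and blend compute (illustrative):

#include <cstdint>

inline int32_t FixedPointMultiply(int32_t gate, int32_t reset,
                                  int state_mantissa_bits) {
  // 32x32 -> 64-bit multiply, then shift back to the gate's scale; the AVX2
  // code achieves the same with _mm256_srli_epi64 / _mm256_slli_epi64 and a
  // blend before truncating to 32 bits.
  int64_t product = static_cast<int64_t>(gate) * static_cast<int64_t>(reset);
  return static_cast<int32_t>(product >> state_mantissa_bits);
}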
561
+ template <typename Type>
562
+ typename std::enable_if<IsFixed32Type<Type>::value>::type SumVectors(
563
+ int start, int end, const Type* add1, const Type* add2, Type* result) {
564
+ constexpr int kSIMDWidth = 8;
565
+ for (int i = start; i < end; i += kSIMDWidth) {
566
+ __m256i data1 =
567
+ _mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i));
568
+ __m256i data2 =
569
+ _mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i));
570
+ data1 = _mm256_add_epi32(data1, data2);
571
+ _mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1);
572
+ }
573
+ }
574
+
575
+ template <typename Type>
576
+ typename std::enable_if<IsFixed16Type<Type>::value>::type SumVectors(
577
+ int start, int end, const Type* add1, const Type* add2, Type* result) {
578
+ constexpr int kSIMDWidth = 16;
579
+ for (int i = start; i < end; i += kSIMDWidth) {
580
+ __m256i data1 =
581
+ _mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i));
582
+ __m256i data2 =
583
+ _mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i));
584
+ data1 = _mm256_add_epi16(data1, data2);
585
+ _mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1);
586
+ }
587
+ }
588
+
589
+ #endif // __AVX2__
590
+
591
+ } // namespace detail
592
+ } // namespace csrblocksparse
593
+
594
+ #undef LABEL_COL_LOOP
595
+ #undef LABEL_ROW_LOOP
596
+ #undef LABEL_SKIP_COL_LOOP
597
+ #undef LABEL_TOP_LOOP
598
+
599
+ #endif // __AVX__
600
+
601
+ #endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_
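Editor's note: for comparison, the behaviour both SumVectors overloads implement is just an element-wise add over [start, end); the AVX2 bodies process 8 (fixed32) or 16 (fixed16) elements per iteration. A scalar sketch for plain arithmetic element types:

template <typename T>
void SumVectorsScalar(int start, int end, const T* add1, const T* add2,
                      T* result) {
  for (int i = start; i < end; ++i) {
    result[i] = add1[i] + add2[i];
  }
}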
sparse_matmul/compute/kernels_generic.h ADDED
@@ -0,0 +1,273 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_
19
+
20
+ #include <algorithm>
21
+ #include <type_traits>
22
+
23
+ #include "sparse_matmul/numerics/fixed_types.h"
24
+ #include "sparse_matmul/numerics/float16_types.h"
25
+ #include "sparse_matmul/numerics/type_utils.h"
26
+
27
+ // Separate out the assembly kernels for readability. Eventually this will
28
+ // become an ifdef switch on the architecture type.
29
+ #if defined __aarch64__
30
+ #include "sparse_matmul/compute/kernels_arm.h"
31
+ #elif defined __AVX__
32
+ #include "sparse_matmul/compute/kernels_avx.h"
33
+ #else // defined __AVX__
34
+ // If there is no architecture-specific implementation, then always use generic.
35
+ template <typename WeightType, typename RhsType, typename OutType>
36
+ struct ShouldEnableGenericSpMV_4x4 : std::true_type {};
37
+ template <typename WeightType, typename RhsType, typename OutType>
38
+ struct ShouldEnableGenericSpMM5_4x4 : std::true_type {};
39
+ template <typename WeightType, typename RhsType, typename OutType>
40
+ struct ShouldEnableGenericSpMV_1x1 : std::true_type {};
41
+ template <typename WeightType, typename RhsType, typename OutType>
42
+ struct ShouldEnableGenericSpMM5_1x1 : std::true_type {};
43
+ template <typename Type>
44
+ struct ShouldEnableGenericAdd : std::true_type {};
45
+ #endif // defined __arch64__
46
+
47
+ namespace csrblocksparse {
48
+ namespace detail {
49
+
50
+ // The computational routines do NO error checking for speed. It is assumed
51
+ // that this has been handled by CSRBlockSparseMatrix.
52
+
53
+ // Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
54
+ // blocked pattern, x is a vector and b is a vector. Weights are stored for this
55
+ // routine by making each 4x4 block contiguous. Blocks are ordered in standard
56
+ // row-major format. Column indices are converted to deltas and then multiplied
57
+ // by 2 to convert to bytes, so that the value can be used directly to offset
58
+ // the pointer into the rhs vector.
59
+ //
60
+ // NOTE: The bias is expected to have been multiplied by .25f prior to calling
61
+ // this function. This is automatically taken care of in SparseLinearLayer.
62
+ // The bias is reconstructed through horizontal additions, which leads to a small
63
+ // speedup by reducing latencies at the end of the loop.
64
+ template <typename WeightType, typename RhsType, typename OutType>
65
+ typename std::enable_if<
66
+ ShouldEnableGenericSpMV_4x4<WeightType, RhsType, OutType>::value>::type
67
+ SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
68
+ const int32_t* nnz_per_row, const RhsType* rhs_ptr,
69
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
70
+ OutType* out_ptr, int64_t assigned_rows,
71
+ int64_t rows /* only used in SpMM variants */,
72
+ int64_t cols /* only used in SpMM variants */, int relu) {
73
+ for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
74
+ float accumulators[4];
75
+ // Undo the division by 4 that happens for the assembly version.
76
+ for (int i = 0; i < 4; ++i)
77
+ accumulators[i] = 4.f * static_cast<float>(*bias_ptr++);
78
+
79
+ int reduced_col_count = *nnz_per_row++;
80
+ for (int c = 0; c < reduced_col_count; ++c) {
81
+ int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
82
+ rhs_ptr += col_delta;
83
+
84
+ // Multiply this 4x4 block.
85
+ for (int i = 0; i < 4; ++i) {
86
+ for (int j = 0; j < 4; ++j) {
87
+ accumulators[i] += static_cast<float>(*weights_ptr++) *
88
+ static_cast<float>(rhs_ptr[j]);
89
+ }
90
+ }
91
+ }
92
+
93
+ for (int i = 0; i < 4; ++i)
94
+ *out_ptr++ = static_cast<OutType>(relu ? std::max(accumulators[i], 0.f)
95
+ : accumulators[i]);
96
+ }
97
+ }
98
+
99
+ // Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
100
+ // blocked pattern, x is a fat vector with 5 columns and b is a vector. b is
101
+ // broadcast. Weights are stored for this routine by making each 4x4 block
102
+ // contiguous. Blocks are ordered in standard row-major format. Column indices
103
+ // are converted to deltas and then multiplied by 2 to convert to bytes, so
104
+ // that the value can be used directly to offset the pointer into the rhs
105
+ // vector.
106
+ //
107
+ // NOTE: The bias is expected to have been multiplied by .25f prior to calling
108
+ // this function. This is automatically taken care of in SparseLinearLayer.
109
+ // The bias is reconstructed through horizontal additions, which leads to a small
110
+ // speedup by reducing latencies at the end of the loop.
111
+ template <typename WeightType, typename RhsType, typename OutType>
112
+ typename std::enable_if<
113
+ ShouldEnableGenericSpMM5_4x4<WeightType, RhsType, OutType>::value>::type
114
+ SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
115
+ const int32_t* nnz_per_row, const RhsType* rhs_ptr,
116
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
117
+ OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
118
+ int relu) {
119
+ const RhsType* rhs_ptrs[5];
120
+ for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;
121
+
122
+ OutType* out_ptrs[5];
123
+ for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;
124
+
125
+ for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
126
+ float accumulators[4][5];
127
+ // Undo the division by 4 that happens for the assembly version.
128
+ for (int i = 0; i < 4; ++i) {
129
+ for (int k = 0; k < 5; ++k) {
130
+ accumulators[i][k] = 4.f * static_cast<float>(*bias_ptr);
131
+ }
132
+ ++bias_ptr;
133
+ }
134
+
135
+ int reduced_col_count = *nnz_per_row++;
136
+ for (int c = 0; c < reduced_col_count; ++c) {
137
+ int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
138
+ for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta;
139
+
140
+ // multiply this 4x4 block
141
+ for (int i = 0; i < 4; ++i) {
142
+ for (int j = 0; j < 4; ++j) {
143
+ for (int k = 0; k < 5; ++k) {
144
+ accumulators[i][k] += static_cast<float>(*weights_ptr) *
145
+ static_cast<float>(rhs_ptrs[k][j]);
146
+ }
147
+ weights_ptr++;
148
+ }
149
+ }
150
+ }
151
+
152
+ for (int k = 0; k < 5; ++k) {
153
+ for (int i = 0; i < 4; ++i) {
154
+ out_ptrs[k][0] = static_cast<OutType>(
155
+ relu ? std::max(accumulators[i][k], 0.f) : accumulators[i][k]);
156
+ out_ptrs[k]++;
157
+ }
158
+ }
159
+ }
160
+ }
161
+
162
+ // Performs the calculation y = A * x + b where A is a sparse matrix with
163
+ // a 1x1 blocked pattern (i.e. unstructured), x is a
164
+ // vector and b is a vector.
165
+ // Weights are stored for this routine in standard CSR format. Each row must
166
+ // have a multiple of 8 columns.
167
+ // Column indices are converted to deltas and then multiplied by 2 to convert
168
+ // to bytes, so that the value can be used directly to offset the pointer
169
+ // into the rhs vector.
170
+ // NOTE: The bias is expected to have been multiplied by .25f prior to calling
171
+ // this function. This is automatically taken care of in SparseLinearLayer.
172
+ // The bias is reconstructed through horizontal additions, which leads to a small
173
+ // speedup by reducing latencies at the end of the loop.
174
+ template <typename WeightType, typename RhsType, typename OutType>
175
+ typename std::enable_if<
176
+ ShouldEnableGenericSpMV_1x1<WeightType, RhsType, OutType>::value>::type
177
+ SpMV_1x1(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
178
+ const int32_t* nnz_per_row, const RhsType* rhs_ptr,
179
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
180
+ OutType* out_ptr, int64_t assigned_rows,
181
+ int64_t rows /* only used in SpMM variants */,
182
+ int64_t cols /* only used in SpMM variants */, int relu) {
183
+ for (int row = 0; row < assigned_rows; ++row) {
184
+ // Undo the division by 4 that happens for the assembly version.
185
+ float accumulator = 4.f * static_cast<float>(*bias_ptr++);
186
+
187
+ int col_count = *nnz_per_row++;
188
+ for (int c = 0; c < col_count; ++c) {
189
+ int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
190
+ rhs_ptr += col_delta;
191
+
192
+ accumulator +=
193
+ static_cast<float>(*weights_ptr++) * static_cast<float>(*rhs_ptr);
194
+ }
195
+
196
+ *out_ptr++ =
197
+ static_cast<OutType>(relu ? std::max(accumulator, 0.f) : accumulator);
198
+ }
199
+ }
200
+
201
+ // Performs the calculation y = A * x + b where A is a sparse matrix with
202
+ // a 1x1 blocked pattern (i.e. unstructured), x is a fat
203
+ // vector with 5 columns and b is a vector. b is broadcast.
204
+ // Weights are stored for this routine in standard CSR format. Each row must
205
+ // have a multiple of 8 columns.
206
+ // Column indices are converted to deltas and then multiplied by 2 to convert
207
+ // to bytes, so that the value can be used directly to offset the pointer
208
+ // into the rhs vector.
209
+ // NOTE: The bias is expected to have been multiplied by .25f prior to calling
210
+ // this function. This is automatically taken care of in SparseLinearLayer.
211
+ // The bias is reconstructed through horizontal additions, which leads to a small
212
+ // speedup by reducing latencies at the end of the loop.
213
+ template <typename WeightType, typename RhsType, typename OutType>
214
+ typename std::enable_if<
215
+ ShouldEnableGenericSpMM5_1x1<WeightType, RhsType, OutType>::value>::type
216
+ SpMM5_1x1(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
217
+ const int32_t* nnz_per_row, const RhsType* rhs_ptr,
218
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
219
+ OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
220
+ int relu) {
221
+ const RhsType* rhs_ptrs[5];
222
+ for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;
223
+
224
+ OutType* out_ptrs[5];
225
+ for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;
226
+
227
+ for (int row = 0; row < assigned_rows; ++row) {
228
+ // Undo the division by 4 that happens for the assembly version.
229
+ float accumulator[5];
230
+ for (int i = 0; i < 5; ++i)
231
+ accumulator[i] = 4.f * static_cast<float>(*bias_ptr);
232
+
233
+ ++bias_ptr;
234
+
235
+ int col_count = *nnz_per_row++;
236
+ for (int c = 0; c < col_count; ++c) {
237
+ int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
238
+ for (int i = 0; i < 5; ++i) {
239
+ rhs_ptrs[i] += col_delta;
240
+ accumulator[i] += static_cast<float>(*weights_ptr) *
241
+ static_cast<float>(rhs_ptrs[i][0]);
242
+ }
243
+ weights_ptr++;
244
+ }
245
+
246
+ for (int i = 0; i < 5; ++i) {
247
+ out_ptrs[i][0] = static_cast<OutType>(relu ? std::max(accumulator[i], 0.f)
248
+ : accumulator[i]);
249
+ out_ptrs[i]++;
250
+ }
251
+ }
252
+ }
253
+
254
+ template <typename Type>
255
+ typename std::enable_if<ShouldEnableGenericAdd<Type>::value>::type SumVectors(
256
+ int start, int end, const Type* add1, const Type* add2, Type* result) {
257
+ LOG_FIRST_N(WARNING, 1) << "SumVectors: using generic kernel!";
258
+ for (int i = start; i < end; ++i) {
259
+ Type sum = static_cast<Type>(static_cast<float>(add1[i]) +
260
+ static_cast<float>(add2[i]));
261
+ result[i] = sum;
262
+ }
263
+ }
264
+
265
+ } // namespace detail
266
+ } // namespace csrblocksparse
267
+
268
+ #undef LABEL_COL_LOOP
269
+ #undef LABEL_ROW_LOOP
270
+ #undef LABEL_SKIP_COL_LOOP
271
+ #undef LABEL_TOP_LOOP
272
+
273
+ #endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_
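Note on the storage conventions used by the generic kernels above: column offsets arrive as byte deltas into the rhs vector, and the bias arrives pre-scaled by 0.25f, which the kernel undoes by multiplying by 4. The following standalone sketch (plain C++, hypothetical data, not part of the library) walks one 4x4 block the same way SpMV_4x4 does:

// Standalone illustration of the SpMV_4x4 storage conventions above.
// The weights, rhs and bias values here are made up for illustration.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // One 4x4 block at block-column 0: y = W * x + b with W == identity.
  std::vector<float> weights = {1, 0, 0, 0,  0, 1, 0, 0,
                                0, 0, 1, 0,  0, 0, 0, 1};
  std::vector<float> rhs = {1.f, 2.f, 3.f, 4.f};
  // Bias is stored pre-multiplied by .25f, as SparseLinearLayer does.
  std::vector<float> bias = {0.25f * 10.f, 0.25f * 20.f, 0.25f * 30.f, 0.25f * 40.f};
  std::vector<int32_t> nnz_per_row = {1};       // one block in this row.
  std::vector<int16_t> col_deltas_bytes = {0};  // delta to column 0, in bytes.

  const float* rhs_ptr = rhs.data();
  const float* w = weights.data();
  float out[4];
  for (int row = 0; row < 1; ++row) {
    float acc[4];
    for (int i = 0; i < 4; ++i) acc[i] = 4.f * bias[row * 4 + i];  // undo the .25f.
    for (int c = 0; c < nnz_per_row[row]; ++c) {
      // Byte delta converted back to an element delta, as in the kernel.
      rhs_ptr += col_deltas_bytes[c] / static_cast<int>(sizeof(float));
      for (int i = 0; i < 4; ++i)
        for (int j = 0; j < 4; ++j) acc[i] += w[16 * c + 4 * i + j] * rhs_ptr[j];
    }
    for (int i = 0; i < 4; ++i) out[i] = acc[i];
  }
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // 11 22 33 44
  return 0;
}

Running it prints 11 22 33 44, i.e. W*x plus the un-scaled bias.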
sparse_matmul/compute/matmul.h ADDED
@@ -0,0 +1,199 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_
19
+
20
+ #include <cstdint>
21
+ #include <vector>
22
+
23
+ #include "absl/time/time.h"
24
+ #include "sparse_matmul/compute/matmul_fixed_avx2.h"
25
+ #include "sparse_matmul/compute/matmul_generic.h"
26
+ #include "sparse_matmul/numerics/fixed_types.h"
27
+ #include "sparse_matmul/numerics/type_utils.h"
28
+ #if defined(__x86_64__) || defined(__i386__) || defined(_WIN32)
29
+ #include <cpuid.h>
30
+ #endif
31
+
32
+ namespace csrblocksparse {
33
+
34
+ // The number of elements in a block.
35
+ constexpr int kBlockSize = 4;
36
+
37
+ // Base class for Matmul containing the members that are non-type-specific.
38
+ class MatmulBase {
39
+ public:
40
+ // Constructor initializes the flags that determine which implementation to
41
+ // use at run-time, constrained by both compiler flags and cpuid.
42
+ MatmulBase() {
43
+ #if defined(__x86_64__) || defined(__i386__) || defined(_WIN32)
44
+ // Code tested to work on Linux systems and multiple Android emulators.
45
+ unsigned int eax, ebx, ecx, edx;
46
+ if (__get_cpuid(1, &eax, &ebx, &ecx, &edx) != 0) {
47
+ using_avx_ = (ecx & bit_AVX) != 0;
48
+ if (using_avx_) {
49
+ __get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx);
50
+ using_avx2_ = (ebx & bit_AVX2) != 0;
51
+ using_avx512_ = (ebx & bit_AVX512F) != 0 && (ebx & bit_AVX512DQ) &&
52
+ (ebx & bit_AVX512BW) != 0;
53
+ VLOG(2) << "avx2 flag=" << using_avx2_ << " 512=" << using_avx512_;
54
+ } else {
55
+ LOG(ERROR) << "AVX not found at all!";
56
+ }
57
+ }
58
+ #else
59
+ using_aarch64_ = true;
60
+ #endif
61
+ }
62
+
63
+ protected:
64
+ // Flags that define what (runtime) architectures are available. Flags that
65
+ // are set are limited by both the compiler flags and runtime environment.
66
+ bool using_avx512_ = false;
67
+ bool using_avx2_ = false;
68
+ bool using_avx_ = false;
69
+ bool using_aarch64_ = false;
70
+ };
71
+
72
+ // The master template is really a catch-all for the unimplemented cases to
73
+ // report an error.
74
+ template <typename WeightType, typename RhsType>
75
+ class Matmul : public MatmulBase {
76
+ public:
77
+ // Sparse inputs, outputs replicated strided for each thread.
78
+ template <typename OutType>
79
+ void MatVec4x4(const WeightType* weights, const RhsType* rhs,
80
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias,
81
+ const int32_t* nnz_per_row, const int16_t* rhs_indices,
82
+ int start_row, int end_row, bool relu, int replicas,
83
+ int stride, OutType* output) {
84
+ // The specializations should take care of every real case.
85
+ CHECK(false) << "Unsupported combination of types used!";
86
+ }
87
+ template <typename OutType>
88
+ void MatVec8x4(const WeightType* weights, const RhsType* rhs,
89
+ const typename TypeOfProduct<WeightType, RhsType>::type* bias,
90
+ const int32_t* nnz_per_row, const int16_t* rhs_indices,
91
+ int start_row, int end_row, bool relu, int replicas,
92
+ int stride, OutType* output) {
93
+ // The specializations should take care of every real case.
94
+ CHECK(false) << "Unsupported combination of types used!";
95
+ }
96
+ };
97
+
98
+ // Full specialization for float.
99
+ template <>
100
+ class Matmul<float, float> : public MatmulBase {
101
+ public:
102
+ void MatVec4x4(const float* weights, const float* rhs, const float* bias,
103
+ const int32_t* nnz_per_row, const int16_t* rhs_indices,
104
+ int start_row, int end_row, bool relu, int replicas,
105
+ int stride, float* output) {
106
+ detail::MatVecFloatGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
107
+ start_row, end_row, /*block_height=*/4,
108
+ /*block_width=*/4, relu, replicas, stride,
109
+ output);
110
+ }
111
+ void MatVec8x4(const float* weights, const float* rhs, const float* bias,
112
+ const int32_t* nnz_per_row, const int16_t* rhs_indices,
113
+ int start_row, int end_row, bool relu, int replicas,
114
+ int stride, float* output) {
115
+ detail::MatVecFloatGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
116
+ start_row, end_row, /*block_height=*/8,
117
+ /*block_width=*/4, relu, replicas, stride,
118
+ output);
119
+ }
120
+ };
121
+
122
+ // Partial specialization for fixed types. Covers fixed16xfixed16 = OutType,
123
+ // where OutType should be fixed16 or fixed32. The mantissa bits don't have
124
+ // to match.
125
+ template <int WeightBits, int RhsBits>
126
+ class Matmul<fixed16<WeightBits>, fixed16<RhsBits>> : public MatmulBase {
127
+ public:
128
+ using WeightType = fixed16<WeightBits>;
129
+ using RhsType = fixed16<RhsBits>;
130
+
131
+ template <typename OutType>
132
+ void MatVec4x4(const int16_t* weights, const int16_t* rhs,
133
+ const int32_t* bias, const int32_t* nnz_per_row,
134
+ const int16_t* rhs_indices, int start_row, int end_row,
135
+ bool relu, int replicas, int stride, OutType* output) {
136
+ constexpr int kShiftAmount =
137
+ TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
138
+ OutType::kMantissaBits;
139
+ static_assert(kShiftAmount >= 0,
140
+ "OutType must not have more mantissa bits than inputs");
141
+ #if defined __AVX2__
142
+ CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!";
143
+ if (sizeof(*output) == 4) {
144
+ int32_t* out32 = reinterpret_cast<int32_t*>(output);
145
+ detail::MatVec4x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
146
+ start_row, end_row, relu, kShiftAmount,
147
+ replicas, stride, out32);
148
+ } else {
149
+ int16_t* out16 = reinterpret_cast<int16_t*>(output);
150
+ detail::MatVec4x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
151
+ start_row, end_row, relu, kShiftAmount,
152
+ replicas, stride, out16);
153
+ }
154
+ #elif defined __aarch64__
155
+ if (using_aarch64_) {
156
+ LOG(FATAL) << "Fixed16 MatVec4x4 not yet implemented!";
157
+ }
158
+
159
+ #else
160
+ detail::MatVecFixedGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
161
+ start_row, end_row, /*block_height=*/4,
162
+ /*block_width=*/4, relu, sizeof(*output),
163
+ kShiftAmount, replicas, stride, output);
164
+ #endif // __AVX2__
165
+ }
166
+
167
+ template <typename OutType>
168
+ void MatVec8x4(const int16_t* weights, const int16_t* rhs,
169
+ const int32_t* bias, const int32_t* nnz_per_row,
170
+ const int16_t* rhs_indices, int start_row, int end_row,
171
+ bool relu, int replicas, int stride, OutType* output) {
172
+ constexpr int kShiftAmount =
173
+ TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
174
+ OutType::kMantissaBits;
175
+ static_assert(kShiftAmount >= 0,
176
+ "OutType must not have more mantissa bits than inputs");
177
+ #if defined __AVX2__
178
+ CHECK(replicas == 1 && sizeof(*output) == 4)
179
+ << "Only replicas == 1 and fixed32 output are implemented for AVX2!";
180
+ CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!";
181
+ int32_t* out32 = reinterpret_cast<int32_t*>(output);
182
+ detail::MatVec8x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
183
+ start_row, end_row, relu, kShiftAmount, out32);
184
+ #elif defined __aarch64__
185
+ if (using_aarch64_) {
186
+ LOG(FATAL) << "Fixed16 MatVec8x4 not yet implemented!";
187
+ }
188
+ #else
189
+ detail::MatVecFixedGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
190
+ start_row, end_row, /*block_height=*/8,
191
+ /*block_width=*/4, relu, sizeof(*output),
192
+ kShiftAmount, replicas, stride, output);
193
+ #endif // __AVX2__
194
+ }
195
+ };
196
+
197
+ } // namespace csrblocksparse
198
+
199
+ #endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_
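The fixed-point specializations above compute kShiftAmount as the difference between the mantissa bits of the product/accumulator type and those of OutType, then shift right with rounding. A minimal standalone sketch of that rescaling, using a hypothetical helper and hypothetical Q4.11 operands rather than the library's fixed16 types:

// Standalone sketch of the fixed-point rescaling done via kShiftAmount.
#include <cstdint>
#include <cstdio>

// Multiply two Q4.11 values; the product carries 11 + 11 fractional bits,
// so reducing to |out_mantissa| fractional bits needs a right shift with
// rounding, analogous to the shift_out handling in the AVX2 kernels.
// |out_mantissa| must not exceed 22, mirroring the static_assert above.
int32_t MulAndRescale(int16_t w_q11, int16_t x_q11, int out_mantissa) {
  const int product_mantissa = 11 + 11;  // mantissa bits add on multiply.
  int32_t product = static_cast<int32_t>(w_q11) * static_cast<int32_t>(x_q11);
  int shift = product_mantissa - out_mantissa;     // analogous to kShiftAmount.
  int32_t rounding = shift > 0 ? (1 << (shift - 1)) : 0;
  return (product + rounding) >> shift;
}

int main() {
  // 1.5 * 2.0 in Q4.11: 1.5 -> 3072, 2.0 -> 4096.
  int32_t y = MulAndRescale(3072, 4096, 11);  // result back in Q*.11
  std::printf("%f\n", y / 2048.0);            // prints 3.000000
  return 0;
}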
sparse_matmul/compute/matmul_fixed_avx2.cc ADDED
@@ -0,0 +1,235 @@
1
+ // Copyright 2021 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #include "sparse_matmul/compute/matmul_fixed_avx2.h"
16
+
17
+ #include <cstdint>
18
+
19
+ #if defined __AVX__
20
+ #include <immintrin.h>
21
+ #endif
22
+
23
+ #include "sparse_matmul/compute/matmul.h"
24
+
25
+ namespace csrblocksparse {
26
+ namespace detail {
27
+
28
+ static const int32_t kint32min = static_cast<int32_t>(~0x7FFFFFFF);
29
+ static const int32_t kint32max = static_cast<int32_t>(0x7FFFFFFF);
30
+
31
+ #if defined __AVX2__
32
+ // In-line function computes and returns the result of one row (of blocks) as
33
+ // 4x int32_t. |weights_ptr| is a non-const reference so it can easily be
34
+ // interpreted as belonging to the caller.
35
+ inline __m256i ComputeRowResults(const __m128i& bias128, const int16_t* rhs,
36
+ const int16_t* rhs_indices, int nnz,
37
+ int16_t const*& weights_ptr) {
38
+ // Expand bias to 64 bits in a 256 bit register [0 z 1 z 2 z 3 z], where z is
39
+ // Zero and 0-3 are the 4x32 bit bias values.
40
+ __m256i sum = _mm256_cvtepu32_epi64(bias128);
41
+
42
+ for (int c = 0; c < nnz; ++c) {
43
+ int rhs_index = rhs_indices[c];
44
+ // Load all 16 weights.
45
+ __m256i weights =
46
+ _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
47
+ // Get the 4x int16_t into the bottom of |rhs_64|.
48
+ __m128i rhs_64 = _mm_loadl_epi64(
49
+ reinterpret_cast<__m128i const*>(rhs + rhs_index * kBlockSize));
50
+ // Broadcast the rhs, pretending that each is a 64-bit unit:
51
+ // [0123 0123 0123 0123].
52
+ __m256i rhs_value = _mm256_broadcastq_epi64(rhs_64);
53
+ weights_ptr += 16;
54
+ sum = _mm256_add_epi32(sum, _mm256_madd_epi16(weights, rhs_value));
55
+ }
56
+ // Horizontally add the results. We have 1 register that contains results
57
+ // [0 0 1 1 2 2 3 3], but hadd (and almost no other AVX instruction) will not
58
+ // cross lanes, so we end up with [0 1 0 1 2 3 2 3]
59
+ sum = _mm256_hadd_epi32(sum, sum);
60
+ // Permutes the middle two pairs to get the answers together.
61
+ return _mm256_permute4x64_epi64(sum, 0xd8);
62
+ }
63
+
64
+ // Template that allows any fixed combination of OutType and replicas, plus
65
+ // variable |relu|, |shift_out|. Note that |kReplicas| is a template arg as
66
+ // well as a function arg so we can hard-code a limited amount of unrolling.
67
+ template <typename OutType, int kReplicas>
68
+ void MatVec4x4FixedAVX2Template(const int16_t* weights_ptr, const int16_t* rhs,
69
+ const int32_t* bias, const int32_t* nnz_per_row,
70
+ const int16_t* rhs_indices, int start_row,
71
+ int end_row, bool relu, int shift_out,
72
+ int replicas, int stride, OutType* output) {
73
+ int rounding_addon = shift_out > 0 ? (1 << (shift_out - 1)) : 0;
74
+ __m256i rounding = _mm256_set1_epi32(rounding_addon);
75
+ __m256i zero = relu ? _mm256_setzero_si256() : _mm256_set1_epi32(kint32min);
76
+ for (int row_block = start_row; row_block < end_row; ++row_block) {
77
+ // Load 4 biases [0 1 2 3].
78
+ __m128i bias128 = _mm_load_si128(reinterpret_cast<__m128i const*>(bias));
79
+ bias += kBlockSize;
80
+ int nnz = nnz_per_row[row_block];
81
+ __m256i sum =
82
+ ComputeRowResults(bias128, rhs, rhs_indices, nnz, weights_ptr);
83
+ rhs_indices += nnz;
84
+ // Shift right with rounding to get the right number of mantissa bits.
85
+ sum = _mm256_add_epi32(sum, rounding);
86
+ sum = _mm256_srai_epi32(sum, shift_out);
87
+ // Now sum contains [res0, res1, res2, res3, res0, res1, res2, res3]
88
+ sum = _mm256_max_epi32(sum, zero);
89
+ if (sizeof(OutType) == 2) {
90
+ // Clip to 16 bit range (with saturation) and pack in the bottom 64
91
+ // bits. The 64 bit result is replicated across the whole 256 bit
92
+ // register. [0123 0123 0123 0123]
93
+ sum = _mm256_packs_epi32(sum, sum);
94
+ int64_t result = _mm256_extract_epi64(sum, 0);
95
+ *reinterpret_cast<int64_t*>(output) = result;
96
+ if (kReplicas > 1) {
97
+ *reinterpret_cast<int64_t*>(output + stride) = result;
98
+ if (kReplicas > 2) {
99
+ for (int r = 2; r < replicas; ++r) {
100
+ *reinterpret_cast<int64_t*>(output + r * stride) = result;
101
+ }
102
+ }
103
+ }
104
+ } else {
105
+ // Save the lower 128 bits (4x int32_t).
106
+ __m128i result = _mm256_extractf128_si256(sum, 0);
107
+ _mm_store_si128(reinterpret_cast<__m128i*>(output), result);
108
+ if (kReplicas > 1) {
109
+ _mm_store_si128(reinterpret_cast<__m128i*>(output + stride), result);
110
+ if (kReplicas > 2) {
111
+ for (int r = 2; r < replicas; ++r) {
112
+ _mm_store_si128(reinterpret_cast<__m128i*>(output + r * stride),
113
+ result);
114
+ }
115
+ }
116
+ }
117
+ }
118
+ output += kBlockSize;
119
+ }
120
+ }
121
+
122
+ // Version that covers all possible combinations of the variable conditions:
123
+ // |relu|, |shift_out|, |replicas|, with int16_t |output|.
124
+ void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
125
+ const int32_t* bias, const int32_t* nnz_per_row,
126
+ const int16_t* rhs_indices, int start_row, int end_row,
127
+ bool relu, int shift_out, int replicas, int stride,
128
+ int16_t* output) {
129
+ if (replicas <= 1) {
130
+ MatVec4x4FixedAVX2Template<int16_t, 1>(weights_ptr, rhs, bias, nnz_per_row,
131
+ rhs_indices, start_row, end_row,
132
+ relu, shift_out, 1, stride, output);
133
+ } else if (replicas == 2) {
134
+ MatVec4x4FixedAVX2Template<int16_t, 2>(weights_ptr, rhs, bias, nnz_per_row,
135
+ rhs_indices, start_row, end_row,
136
+ relu, shift_out, 2, stride, output);
137
+ } else {
138
+ MatVec4x4FixedAVX2Template<int16_t, 3>(
139
+ weights_ptr, rhs, bias, nnz_per_row, rhs_indices, start_row, end_row,
140
+ relu, shift_out, replicas, stride, output);
141
+ }
142
+ }
143
+
144
+ // Version that covers all possible combinations of the variable conditions:
145
+ // |relu|, |shift_out|, |replicas|, with int32_t |output|.
146
+ void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
147
+ const int32_t* bias, const int32_t* nnz_per_row,
148
+ const int16_t* rhs_indices, int start_row, int end_row,
149
+ bool relu, int shift_out, int replicas, int stride,
150
+ int32_t* output) {
151
+ if (replicas <= 1) {
152
+ MatVec4x4FixedAVX2Template<int32_t, 1>(weights_ptr, rhs, bias, nnz_per_row,
153
+ rhs_indices, start_row, end_row,
154
+ relu, shift_out, 1, stride, output);
155
+ } else if (replicas == 2) {
156
+ MatVec4x4FixedAVX2Template<int32_t, 2>(weights_ptr, rhs, bias, nnz_per_row,
157
+ rhs_indices, start_row, end_row,
158
+ relu, shift_out, 2, stride, output);
159
+ } else {
160
+ MatVec4x4FixedAVX2Template<int32_t, 3>(
161
+ weights_ptr, rhs, bias, nnz_per_row, rhs_indices, start_row, end_row,
162
+ relu, shift_out, replicas, stride, output);
163
+ }
164
+ }
165
+
166
+ // In-line function computes and returns the result of one row (of blocks) as
167
+ // 8x int32_t. weights_ptr is a non-const reference so it can easily be
168
+ // interpreted as belonging to the caller.
169
+ inline __m256i Compute8RowResults(const __m256i& bias256, const int16_t* rhs,
170
+ const int16_t* rhs_indices, int nnz,
171
+ int16_t const*& weights_ptr) {
172
+ // Expand bias to 64 bits in a 256 bit register [0 z 1 z 2 z 3 z], where z is
173
+ // Zero and 0-3 are the 4x32 bit bias values from 128 bit half of the input.
174
+ __m256i sum1 = _mm256_cvtepu32_epi64(_mm256_castsi256_si128(bias256));
175
+ // Plus 4 more in another sum register from the upper 128 bit half.
176
+ __m256i sum2 = _mm256_cvtepu32_epi64(_mm256_extractf128_si256(bias256, 1));
177
+
178
+ for (int c = 0; c < nnz; ++c) {
179
+ int rhs_index = rhs_indices[c];
180
+ // Load all 16 weights.
181
+ __m256i weights =
182
+ _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
183
+ // Get the 4x int16_t into the bottom of |rhs_64|.
184
+ __m128i rhs_64 = _mm_loadl_epi64(
185
+ reinterpret_cast<__m128i const*>(rhs + rhs_index * kBlockSize));
186
+ // Broadcast the rhs, pretending that each is a 64-bit unit:
187
+ // [0123 0123 0123 0123].
188
+ __m256i rhs_value = _mm256_broadcastq_epi64(rhs_64);
189
+ weights_ptr += 16;
190
+ sum1 = _mm256_add_epi32(sum1, _mm256_madd_epi16(weights, rhs_value));
191
+ // Same again for the other 4 results, re-using the same rhs value.
192
+ weights = _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
193
+ weights_ptr += 16;
194
+ sum2 = _mm256_add_epi32(sum2, _mm256_madd_epi16(weights, rhs_value));
195
+ }
196
+ // Horizontally add the results. We have 2 registers that contain results
197
+ // [0 0 1 1 2 2 3 3], and [4 4 5 5 6 6 7 7] but hadd (and almost no other AVX
198
+ // instruction) will not cross lanes, so we end up with [0 1 4 5 2 3 6 7]
199
+ sum1 = _mm256_hadd_epi32(sum1, sum2);
200
+ // Permutes the middle two pairs to get the answers in the right order.
201
+ return _mm256_permute4x64_epi64(sum1, 0xd8);
202
+ }
203
+
204
+ // Version that covers the main conditions used with 8x4:
205
+ // |relu|, |shift_out|, with int32_t |output|.
206
+ void MatVec8x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
207
+ const int32_t* bias, const int32_t* nnz_per_row,
208
+ const int16_t* rhs_indices, int start_row, int end_row,
209
+ bool relu, int shift_out, int32_t* output) {
210
+ int rounding_addon = shift_out > 0 ? (1 << (shift_out - 1)) : 0;
211
+ __m256i rounding = _mm256_set1_epi32(rounding_addon);
212
+ __m256i zero = relu ? _mm256_setzero_si256() : _mm256_set1_epi32(kint32min);
213
+ for (int row_block = start_row; row_block < end_row; ++row_block) {
214
+ // Load 8 biases [0 1 2 3 4 5 6 7].
215
+ __m256i bias256 = _mm256_load_si256(reinterpret_cast<__m256i const*>(bias));
216
+ bias += kBlockSize * 2;
217
+ int nnz = nnz_per_row[row_block];
218
+ __m256i sum =
219
+ Compute8RowResults(bias256, rhs, rhs_indices, nnz, weights_ptr);
220
+ rhs_indices += nnz;
221
+ // Shift right with rounding to get the right number of mantissa bits.
222
+ sum = _mm256_add_epi32(sum, rounding);
223
+ sum = _mm256_srai_epi32(sum, shift_out);
224
+ // Now sum contains [res0, res1, res2, res3, res4, res5, res6, res7]
225
+ sum = _mm256_max_epi32(sum, zero);
226
+ // Save all 256 bits (8x int32_t).
227
+ _mm256_store_si256(reinterpret_cast<__m256i*>(output), sum);
228
+ output += kBlockSize * 2;
229
+ }
230
+ }
231
+
232
+ #endif
233
+
234
+ } // namespace detail
235
+ } // namespace csrblocksparse
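The AVX2 row kernels above lean on _mm256_madd_epi16, which multiplies corresponding int16 elements and sums adjacent pairs into int32 lanes, so each 4x4 block contributes two partial sums per output row that the later hadd/permute collapses. A scalar model of that step (illustrative only, no intrinsics; data made up):

// Scalar model of the _mm256_madd_epi16 step used in ComputeRowResults.
#include <cstdint>
#include <cstdio>

// Each 32-bit lane is the sum of two adjacent int16 products.
void MaddEpi16Model(const int16_t a[16], const int16_t b[16], int32_t out[8]) {
  for (int lane = 0; lane < 8; ++lane) {
    out[lane] = static_cast<int32_t>(a[2 * lane]) * b[2 * lane] +
                static_cast<int32_t>(a[2 * lane + 1]) * b[2 * lane + 1];
  }
}

int main() {
  // One 4x4 block of weights (row-major) and the broadcast rhs [1 2 3 4].
  int16_t weights[16] = {1, 0, 0, 0,  0, 1, 0, 0,  0, 0, 1, 0,  0, 0, 0, 1};
  int16_t rhs[16] = {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4};
  int32_t lanes[8];
  MaddEpi16Model(weights, rhs, lanes);
  // Adjacent lane pairs hold the two halves of each row's dot product.
  for (int row = 0; row < 4; ++row) {
    std::printf("row %d = %d\n", row, lanes[2 * row] + lanes[2 * row + 1]);
  }
  return 0;
}

With the identity block this prints 1, 2, 3, 4, i.e. W*x for the broadcast rhs.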
sparse_matmul/compute/matmul_fixed_avx2.h ADDED
@@ -0,0 +1,49 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_FIXED_AVX2_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_FIXED_AVX2_H_
19
+
20
+ #include <cstdint>
21
+
22
+ namespace csrblocksparse {
23
+ namespace detail {
24
+
25
+ // Version that covers all possible combinations of the variable conditions:
26
+ // |relu|, |shift_out|, |replicas|, with int16 output.
27
+ void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
28
+ const int32_t* bias, const int32_t* nnz_per_row,
29
+ const int16_t* rhs_indices, int start_row, int end_row,
30
+ bool relu, int shift_out, int replicas, int stride,
31
+ int16_t* output);
32
+ // Version that covers all possible combinations of the variable conditions:
33
+ // |relu|, |shift_out|, |replicas|, with int32 output.
34
+ void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
35
+ const int32_t* bias, const int32_t* nnz_per_row,
36
+ const int16_t* rhs_indices, int start_row, int end_row,
37
+ bool relu, int shift_out, int replicas, int stride,
38
+ int32_t* output);
39
+ // Version that covers the main conditions used with 8x4:
40
+ // |relu|, |shift_out|, with int32 output.
41
+ void MatVec8x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
42
+ const int32_t* bias, const int32_t* nnz_per_row,
43
+ const int16_t* rhs_indices, int start_row, int end_row,
44
+ bool relu, int shift_out, int32_t* output);
45
+
46
+ } // namespace detail
47
+ } // namespace csrblocksparse
48
+
49
+ #endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_FIXED_AVX2_H_
sparse_matmul/compute/matmul_generic.cc ADDED
@@ -0,0 +1,122 @@
1
+ // Copyright 2021 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #include "sparse_matmul/compute/matmul_generic.h"
16
+
17
+ #include <cstdint>
18
+ #include <vector>
19
+
20
+ #include "sparse_matmul/compute/matmul.h"
21
+
22
+ namespace csrblocksparse {
23
+ namespace detail {
24
+
25
+ void MatVecFloatGeneric(const float* weights, const float* rhs,
26
+ const float* bias, const int32_t* nnz_per_row,
27
+ const int16_t* rhs_indices, int start_row, int end_row,
28
+ int block_height, int block_width, bool relu,
29
+ int replicas, int stride, float* output) {
30
+ int weight_index = 0;
31
+ int bias_index = 0;
32
+ std::vector<float> accumulators(block_height);
33
+ for (int row_block = start_row; row_block < end_row;
34
+ ++row_block, output += block_height) {
35
+ int nnz = nnz_per_row[row_block];
36
+ // Biases are now stored and used directly without pre-division.
37
+ for (int i = 0; i < block_height; ++i) accumulators[i] = bias[bias_index++];
38
+
39
+ for (int c = 0; c < nnz; ++c) {
40
+ int rhs_index = rhs_indices[c];
41
+ const float* block_rhs = rhs + rhs_index * block_width;
42
+ // Multiply this |block_height| x |block_width| block.
43
+ for (int i = 0; i < block_height; ++i) {
44
+ for (int j = 0; j < block_width; ++j) {
45
+ accumulators[i] += weights[weight_index++] * block_rhs[j];
46
+ }
47
+ }
48
+ }
49
+ rhs_indices += nnz;
50
+ // Apply relu if desired.
51
+ if (relu) {
52
+ for (int i = 0; i < block_height; ++i) {
53
+ if (accumulators[i] < 0) accumulators[i] = 0;
54
+ }
55
+ }
56
+ for (int r = 0; r < replicas; ++r) {
57
+ for (int i = 0; i < block_height; ++i) {
58
+ output[i + r * stride] = accumulators[i];
59
+ }
60
+ }
61
+ }
62
+ }
63
+
64
+ void MatVecFixedGeneric(const int16_t* weights, const int16_t* rhs,
65
+ const int32_t* bias, const int32_t* nnz_per_row,
66
+ const int16_t* rhs_indices, int start_row, int end_row,
67
+ int block_height, int block_width, bool relu,
68
+ int bytes_out, int shift_out, int replicas, int stride,
69
+ void* output) {
70
+ int weight_index = 0;
71
+ int bias_index = 0;
72
+ std::vector<int32_t> accumulators(block_height);
73
+ for (int row_block = start_row; row_block < end_row; ++row_block) {
74
+ int nnz = nnz_per_row[row_block];
75
+ // Biases are now stored and used directly without pre-division.
76
+ for (int i = 0; i < block_height; ++i) accumulators[i] = bias[bias_index++];
77
+
78
+ for (int c = 0; c < nnz; ++c) {
79
+ int rhs_index = rhs_indices[c];
80
+ const int16_t* block_rhs = rhs + rhs_index * block_width;
81
+ // Multiply this |block_height| x |block_width| block.
82
+ for (int i = 0; i < block_height; ++i) {
83
+ for (int j = 0; j < block_width; ++j) {
84
+ accumulators[i] += weights[weight_index++] * block_rhs[j];
85
+ }
86
+ }
87
+ }
88
+ rhs_indices += nnz;
89
+ // Apply relu if desired.
90
+ if (relu) {
91
+ for (int i = 0; i < block_height; ++i) {
92
+ if (accumulators[i] < 0) accumulators[i] = 0;
93
+ }
94
+ }
95
+ // Output shift.
96
+ if (shift_out > 0) {
97
+ for (int i = 0; i < block_height; ++i) {
98
+ accumulators[i] >>= shift_out;
99
+ }
100
+ }
101
+ if (bytes_out == 2) {
102
+ int16_t* out16 = reinterpret_cast<int16_t*>(output);
103
+ output = out16 + block_height;
104
+ for (int r = 0; r < replicas; ++r, out16 += stride) {
105
+ for (int i = 0; i < block_height; ++i) {
106
+ out16[i] = accumulators[i];
107
+ }
108
+ }
109
+ } else {
110
+ int32_t* out32 = reinterpret_cast<int32_t*>(output);
111
+ output = out32 + block_height;
112
+ for (int r = 0; r < replicas; ++r, out32 += stride) {
113
+ for (int i = 0; i < block_height; ++i) {
114
+ out32[i] = accumulators[i];
115
+ }
116
+ }
117
+ }
118
+ }
119
+ }
120
+
121
+ } // namespace detail
122
+ } // namespace csrblocksparse
sparse_matmul/compute/matmul_generic.h ADDED
@@ -0,0 +1,41 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_GENERIC_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_GENERIC_H_
19
+
20
+ #include <cstdint>
21
+
22
+ namespace csrblocksparse {
23
+ namespace detail {
24
+
25
+ // Generic version uses plain C++ code.
26
+ void MatVecFloatGeneric(const float* weights, const float* rhs,
27
+ const float* bias, const int32_t* nnz_per_row,
28
+ const int16_t* rhs_indices, int start_row, int end_row,
29
+ int block_height, int block_width, bool relu,
30
+ int replicas, int stride, float* output);
31
+ void MatVecFixedGeneric(const int16_t* weights, const int16_t* rhs,
32
+ const int32_t* bias, const int32_t* nnz_per_row,
33
+ const int16_t* rhs_indices, int start_row, int end_row,
34
+ int block_height, int block_width, bool relu,
35
+ int bytes_out, int shift_out, int replicas, int stride,
36
+ void* output);
37
+
38
+ } // namespace detail
39
+ } // namespace csrblocksparse
40
+
41
+ #endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_GENERIC_H_
sparse_matmul/compute/thread_bounds.cc ADDED
@@ -0,0 +1,106 @@
1
+ // Copyright 2021 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #include "sparse_matmul/compute/thread_bounds.h"
16
+
17
+ #include <vector>
18
+
19
+ #include "glog/logging.h"
20
+
21
+ namespace csrblocksparse {
22
+
23
+ void ThreadBounds::PrepareForThreads(int block_width, int block_height,
24
+ int num_threads,
25
+ int reduced_rows_per_cache_row,
26
+ int reduced_rows, const int* nnz_per_row) {
27
+ CHECK_GT(num_threads, 0);
28
+ block_width_ = block_width;
29
+ block_height_ = block_height;
30
+ ComputeThreadSplitPoints(num_threads, reduced_rows_per_cache_row,
31
+ reduced_rows, nnz_per_row);
32
+ weight_starts_.clear();
33
+ rhs_indices_starts_.clear();
34
+ bias_starts_.clear();
35
+ weight_starts_.reserve(row_starts_.size());
36
+ rhs_indices_starts_.reserve(row_starts_.size());
37
+ bias_starts_.reserve(row_starts_.size());
38
+
39
+ // Compute the start indices of each of the types, given what we know about
40
+ // padding, and number of |nnz_per_row|.
41
+ int weight_index = 0;
42
+ int rhs_indices_index = 0;
43
+ int bias_index = 0;
44
+ int row = 0;
45
+ for (int start : row_starts_) {
46
+ while (row < start) {
47
+ weight_index += nnz_per_row[row] * block_width_ * block_height_;
48
+ rhs_indices_index += nnz_per_row[row];
49
+ bias_index += block_height_;
50
+ ++row;
51
+ }
52
+ weight_starts_.push_back(weight_index);
53
+ rhs_indices_starts_.push_back(rhs_indices_index);
54
+ bias_starts_.push_back(bias_index);
55
+ }
56
+ }
57
+
58
+ // Computes the block row (reduced) index of the start of each thread.
59
+ void ThreadBounds::ComputeThreadSplitPoints(int num_threads,
60
+ int reduced_rows_per_cache_row,
61
+ int reduced_rows,
62
+ const int* nnz_per_row) {
63
+ row_starts_.assign(/*n=*/1, /*val=*/0);
64
+ // Break the rule if the matrix is too small to allow one per thread, which
65
+ // occurs only during tests.
66
+ if (reduced_rows_per_cache_row * num_threads > reduced_rows)
67
+ reduced_rows_per_cache_row = std::max(reduced_rows / num_threads, 1);
68
+ int cache_rows = (reduced_rows + reduced_rows_per_cache_row - 1) /
69
+ reduced_rows_per_cache_row;
70
+
71
+ // Compute exclusive prefix sum of the amount of work per row.
72
+ std::vector<int> work_upto_row(cache_rows + 1, 0);
73
+ int extra_row_work = 2 * reduced_rows_per_cache_row;
74
+ for (int i = 0; i < cache_rows; ++i) {
75
+ int new_nnz = 0;
76
+ for (int j = 0; j < reduced_rows_per_cache_row; ++j) {
77
+ // if |reduced_rows_per_cache_row| isn't an exact multiple of the
78
+ // matrix size, then we need to be careful here.
79
+ int index = i * reduced_rows_per_cache_row + j;
80
+ if (index < reduced_rows) new_nnz += nnz_per_row[index];
81
+ }
82
+ work_upto_row[i + 1] = new_nnz + extra_row_work + work_upto_row[i];
83
+ }
84
+ int total_work = work_upto_row.back();
85
+ // Find the split points by assigning an approximately equal amount
86
+ // of work for each thread.
87
+ int prev_split = 0;
88
+ for (int i = 1; i <= num_threads; ++i) {
89
+ int split = std::distance(
90
+ work_upto_row.begin(),
91
+ std::lower_bound(work_upto_row.begin(), work_upto_row.end(),
92
+ i * total_work / num_threads));
93
+ int split_row = split * reduced_rows_per_cache_row;
94
+ if (i == num_threads) {
95
+ split_row = reduced_rows;
96
+ }
97
+
98
+ VLOG(2) << "tid=" << i - 1 << " num rows=" << split_row - row_starts_.back()
99
+ << " work=" << work_upto_row[split] - work_upto_row[prev_split];
100
+ row_starts_.push_back(split_row);
101
+ prev_split = split;
102
+ }
103
+ VLOG(2) << "total rows=" << reduced_rows << " total work=" << total_work;
104
+ }
105
+
106
+ } // namespace csrblocksparse
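ComputeThreadSplitPoints above balances threads by building an exclusive prefix sum of per-row work (nnz plus a fixed per-row overhead) and binary-searching it for each thread's share. A standalone sketch of the same idea, simplified to one block-row per cache row and with made-up nnz counts:

// Standalone sketch of the load-balancing idea in ComputeThreadSplitPoints.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> nnz_per_row = {8, 1, 1, 1, 9, 2, 2, 8};
  const int extra_row_work = 2;  // fixed per-row cost, as in the code above.
  const int num_threads = 2;

  // Exclusive prefix sum of the work per row.
  std::vector<int> work_upto_row(nnz_per_row.size() + 1, 0);
  for (size_t i = 0; i < nnz_per_row.size(); ++i)
    work_upto_row[i + 1] = work_upto_row[i] + nnz_per_row[i] + extra_row_work;
  int total_work = work_upto_row.back();

  // Each thread gets the rows up to roughly t/num_threads of the total work.
  std::vector<int> row_starts = {0};
  for (int t = 1; t <= num_threads; ++t) {
    auto it = std::lower_bound(work_upto_row.begin(), work_upto_row.end(),
                               t * total_work / num_threads);
    int split = static_cast<int>(std::distance(work_upto_row.begin(), it));
    if (t == num_threads) split = static_cast<int>(nnz_per_row.size());
    row_starts.push_back(split);
  }
  for (size_t t = 0; t + 1 < row_starts.size(); ++t)
    std::printf("thread %zu: rows [%d, %d)\n", t, row_starts[t], row_starts[t + 1]);
  return 0;
}

For these numbers the split lands at row 5, giving the two threads roughly equal work despite the uneven nnz counts.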
sparse_matmul/compute/thread_bounds.h ADDED
@@ -0,0 +1,74 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_THREAD_BOUNDS_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_THREAD_BOUNDS_H_
19
+
20
+ #include <vector>
21
+
22
+ namespace csrblocksparse {
23
+
24
+ // Class to compute and store the bounds of each thread used in a computation,
25
+ // and to provide corresponding spans of vectors.
26
+ class ThreadBounds {
27
+ public:
28
+ ThreadBounds() : block_width_(0), block_height_(0) {}
29
+
30
+ void PrepareForThreads(int block_width, int block_height, int num_threads,
31
+ int reduced_rows_per_cache_row, int reduced_rows,
32
+ const int* nnz_per_row);
33
+
34
+ // Functions that offset the appropriate type to the start of the data
35
+ // needed by the given thread id (|tid|).
36
+ template <typename WeightType>
37
+ const WeightType* OffsetWeights(const WeightType* weights, int tid) const {
38
+ return weights + weight_starts_[tid];
39
+ }
40
+ template <typename RhsIndType>
41
+ const RhsIndType* OffsetRhsIndices(const RhsIndType* rhs_indices,
42
+ int tid) const {
43
+ return rhs_indices + rhs_indices_starts_[tid];
44
+ }
45
+ template <typename BiasType>
46
+ const BiasType* OffsetBias(const BiasType* bias, int tid) const {
47
+ return bias + bias_starts_[tid];
48
+ }
49
+ template <typename OutType>
50
+ OutType* OffsetOutput(OutType* output, int tid) const {
51
+ return output + block_height_ * row_starts_[tid];
52
+ }
53
+ int StartRow(int tid) const { return row_starts_[tid]; }
54
+ const std::vector<int>& row_starts() const { return row_starts_; }
55
+
56
+ private:
57
+ // Computes the block row (reduced) index of the start of each thread.
58
+ void ComputeThreadSplitPoints(int num_threads, int reduced_rows_per_cache_row,
59
+ int reduced_rows, const int* nnz_per_row);
60
+
61
+ // Sizes of a sparse block.
62
+ int block_width_;
63
+ int block_height_;
64
+ // Start indices of each data type by thread-id with an extra value at the
65
+ // end.
66
+ std::vector<int> row_starts_;
67
+ std::vector<int> weight_starts_;
68
+ std::vector<int> rhs_indices_starts_;
69
+ std::vector<int> bias_starts_;
70
+ };
71
+
72
+ } // namespace csrblocksparse
73
+
74
+ #endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_THREAD_BOUNDS_H_
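A usage sketch for ThreadBounds as declared above; it only builds inside this source tree (it needs the header plus glog), and the nnz values are invented for illustration:

// Hedged usage sketch: split 8 block-rows of 4x4 blocks across 2 threads,
// then offset per-thread pointers the way a kernel caller would.
#include <cstdio>
#include <vector>
#include "sparse_matmul/compute/thread_bounds.h"

int main() {
  std::vector<int> nnz_per_row = {4, 4, 2, 2, 6, 6, 1, 1};  // per block-row.
  csrblocksparse::ThreadBounds bounds;
  bounds.PrepareForThreads(/*block_width=*/4, /*block_height=*/4,
                           /*num_threads=*/2,
                           /*reduced_rows_per_cache_row=*/1,
                           /*reduced_rows=*/8, nnz_per_row.data());
  // 26 blocks of 16 weights in total; outputs are 4 per block-row.
  std::vector<float> weights(26 * 16), output(8 * 4);
  for (int tid = 0; tid < 2; ++tid) {
    const float* w = bounds.OffsetWeights(weights.data(), tid);
    float* out = bounds.OffsetOutput(output.data(), tid);
    std::printf("tid %d starts at block-row %d (weights offset %ld)\n", tid,
                bounds.StartRow(tid), static_cast<long>(w - weights.data()));
    (void)out;
  }
  return 0;
}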
sparse_matmul/layers/BUILD ADDED
@@ -0,0 +1,146 @@
1
+ # Sparse/Masked Matrix and Layer.
2
+
3
+ # [internal] load android_library_selector
4
+ # [internal] load android_cc_test:def.bzl
5
+
6
+ licenses(["notice"])
7
+
8
+ cc_library(
9
+ name = "layer",
10
+ hdrs = [
11
+ "sparse_linear_layer.h",
12
+ ],
13
+ visibility = [
14
+ "//sparse_matmul:__subpackages__",
15
+ ],
16
+ deps = [
17
+ ":matrix",
18
+ "//sparse_matmul/numerics:types",
19
+ "//sparse_matmul/os:coop_threads",
20
+ "//sparse_matmul/vector:cache_aligned_vector",
21
+ "@com_google_absl//absl/memory",
22
+ "@com_google_absl//absl/strings:str_format",
23
+ "@com_google_glog//:glog",
24
+ ],
25
+ )
26
+
27
+ cc_library(
28
+ name = "matrix",
29
+ hdrs = [
30
+ "csr_blocksparse_matrix.h",
31
+ "masked_sparse_matrix.h",
32
+ ],
33
+ visibility = [
34
+ "//sparse_matmul:__subpackages__",
35
+ ],
36
+ deps = [
37
+ "//sparse_matmul/compute:kernels",
38
+ "//sparse_matmul/compute:matmul",
39
+ "//sparse_matmul/compute:thread_bounds",
40
+ "//sparse_matmul/numerics:types",
41
+ "//sparse_matmul/os:coop_threads",
42
+ "//sparse_matmul/vector:cache_aligned_vector",
43
+ "@com_google_absl//absl/memory",
44
+ "@com_google_absl//absl/strings:str_format",
45
+ "@com_google_glog//:glog",
46
+ ],
47
+ )
48
+
49
+ cc_library(
50
+ name = "utils",
51
+ srcs = [
52
+ "utils.cc",
53
+ ],
54
+ hdrs = [
55
+ "read_array_ifstream.h",
56
+ "utils.h",
57
+ ],
58
+ visibility = [
59
+ "//sparse_matmul:__subpackages__",
60
+ ],
61
+ deps = [
62
+ ":layer",
63
+ ":matrix",
64
+ ":status",
65
+ "//sparse_matmul/numerics:types",
66
+ "//sparse_matmul/vector:cache_aligned_vector",
67
+ "//sparse_matmul/zlib_wrapper",
68
+ "@com_google_absl//absl/status",
69
+ "@com_google_absl//absl/strings",
70
+ "@com_google_absl//absl/strings:cord",
71
+ "@gulrak_filesystem//:filesystem",
72
+ ],
73
+ )
74
+
75
+ cc_library(
76
+ name = "status",
77
+ srcs = [
78
+ "errno_mapping.cc",
79
+ ],
80
+ hdrs = [
81
+ "errno_mapping.h",
82
+ "status_macros.h",
83
+ ],
84
+ deps = [
85
+ "@com_google_absl//absl/status",
86
+ "@com_google_absl//absl/status:statusor",
87
+ "@com_google_absl//absl/strings",
88
+ "@com_google_absl//absl/strings:cord",
89
+ ],
90
+ )
91
+
92
+ cc_test(
93
+ name = "csrblocksparse_test",
94
+ size = "small",
95
+ srcs = [
96
+ "csrblocksparse_test.cc",
97
+ ],
98
+ data = glob(["testdata/*"]),
99
+ linkopts = select({
100
+ "@bazel_tools//platforms:android": ["-landroid"],
101
+ "//conditions:default": [],
102
+ }),
103
+ shard_count = 10,
104
+ deps = [
105
+ ":status",
106
+ ":utils",
107
+ "//sparse_matmul/compute:matmul",
108
+ "//sparse_matmul/numerics:test_utils",
109
+ "//sparse_matmul/os:coop_threads",
110
+ "@com_google_absl//absl/status",
111
+ "@com_google_absl//absl/strings",
112
+ "@com_google_absl//absl/types:span",
113
+ "@com_google_googletest//:gtest_main",
114
+ "@gulrak_filesystem//:filesystem",
115
+ ],
116
+ )
117
+
118
+ cc_test(
119
+ name = "sparse_linear_layer_test",
120
+ srcs = [
121
+ "sparse_linear_layer_test.cc",
122
+ ],
123
+ deps = [
124
+ ":layer",
125
+ "//sparse_matmul/numerics:test_utils",
126
+ "@com_google_googletest//:gtest_main",
127
+ ],
128
+ )
129
+
130
+ cc_test(
131
+ name = "utils_test",
132
+ srcs = ["utils_test.cc"],
133
+ deps = [
134
+ ":layer",
135
+ ":matrix",
136
+ ":status",
137
+ ":utils",
138
+ "//sparse_matmul/numerics:fast_transcendentals",
139
+ "//sparse_matmul/numerics:test_utils",
140
+ "//sparse_matmul/numerics:types",
141
+ "//sparse_matmul/vector:cache_aligned_vector",
142
+ "@com_google_absl//absl/flags:flag",
143
+ "@com_google_googletest//:gtest_main",
144
+ "@gulrak_filesystem//:filesystem",
145
+ ],
146
+ )
sparse_matmul/layers/csr_blocksparse_matrix.h ADDED
@@ -0,0 +1,835 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_
19
+
20
+ #include <algorithm>
21
+ #include <cstdint>
22
+ #include <iostream>
23
+ #include <memory>
24
+ #include <tuple>
25
+ #include <vector>
26
+
27
+ #include "glog/logging.h"
28
+ // IWYU pragma: begin_exports
29
+ #include "sparse_matmul/compute/kernels_generic.h"
30
+ #include "sparse_matmul/compute/matmul.h"
31
+ #include "sparse_matmul/compute/thread_bounds.h"
32
+ #include "sparse_matmul/layers/masked_sparse_matrix.h"
33
+ #include "sparse_matmul/numerics/fixed_types.h"
34
+ #include "sparse_matmul/numerics/float16_types.h"
35
+ #include "sparse_matmul/os/coop_threads.h"
36
+ #include "sparse_matmul/vector/cache_aligned_vector.h"
37
+ // IWYU pragma: end_exports
38
+ #include "absl/memory/memory.h"
39
+
40
+ namespace csrblocksparse {
41
+ // CsrBlockSparseMatrix stores a modified block compressed sparse row
42
+ // representation of a sparse matrix. The ordering of the weights is modified
43
+ // in the 16x1 and 1x1 cases so that a certain number (4 and 8 respectively)
44
+ // of columns of weights are stored contiguously before moving on to the next
45
+ // row. The 4x4 case stores each block contiguously.
46
+ //
47
+ // Currently it is constructed from a MaskedSparseMatrix which uses a dense
48
+ // binary mask representation. The construction generates the compressed
49
+ // representation. Further iterations will support a direct serialization
50
+ // of the compressed representation.
51
+ //
52
+ // MaskedSparseMatrix masked_matrix(rows, cols, existing_mask, existing_values)
53
+ // CsrBlockSparseMatrix matrix(masked_matrix)
54
+ //
55
+ // matrix.SpMV_bias(rhs, bias, &out);
56
+ //
57
+ // This class is thread compatible.
58
+ template <typename WeightType, typename RhsType, typename DeltaType = int16_t>
59
+ class CsrBlockSparseMatrix {
60
+ public:
61
+ CsrBlockSparseMatrix() {}
62
+
63
+ // Reference used to indicate that this is an input and not an output.
64
+ CsrBlockSparseMatrix(const uint8_t* const& buffer, const std::size_t& len) {
65
+ ReadFromFlatBuffer(buffer, len);
66
+ ComputeRHSIndices();
67
+ }
68
+
69
+ template <typename InputType>
70
+ CsrBlockSparseMatrix(const MaskedSparseMatrix<InputType>& masked_matrix) {
71
+ sparsity_ = masked_matrix.sparsity();
72
+ rows_ = masked_matrix.rows();
73
+ cols_ = masked_matrix.cols();
74
+
75
+ DetermineBlockSize(masked_matrix);
76
+
77
+ if (block_width_ == 1 && block_height_ == 1)
78
+ col_multiple_ = 8;
79
+ else
80
+ col_multiple_ = 1;
81
+
82
+ std::vector<InputType> weights(masked_matrix.values().begin(),
83
+ masked_matrix.values().end());
84
+
85
+ reduced_rows_ = (rows_ + block_height_ - 1) / block_height_;
86
+ rows_ = reduced_rows_ * block_height_;
87
+ reduced_cols_ = cols_ / block_width_;
88
+
89
+ // Calculate the reduced CSR representation of the matrix.
90
+ std::vector<int> reduced_mask(reduced_rows_ * reduced_cols_);
91
+ std::vector<int> row_offsets = {0};
92
+ int nnz = 0;
93
+ const auto& mask = masked_matrix.mask();
94
+ for (int r = 0; r < reduced_rows_; ++r) {
95
+ for (int c = 0; c < reduced_cols_; ++c) {
96
+ int mask_val = mask[r * block_height_ * cols_ + c * block_width_];
97
+ reduced_mask[r * reduced_cols_ + c] = mask_val;
98
+ nnz += mask_val;
99
+ }
100
+ row_offsets.push_back(nnz);
101
+ }
102
+
103
+ // Make sure the reduced representation has the correct number of columns.
104
+ MakeColumnsMultiple(row_offsets, &reduced_mask, &weights);
105
+
106
+ std::vector<int> col_indices;
107
+ std::vector<WeightType> weights_csr;
108
+ std::vector<int> nnz_per_row;
109
+ MaskAndWeightsToCsr(reduced_mask, weights, &nnz_per_row, &col_indices,
110
+ &weights_csr);
111
+
112
+ // Generate column deltas from |col_indices|.
113
+ std::vector<DeltaType> col_deltas;
114
+ for (int i = 0; i < col_indices.size(); ++i) {
115
+ // |col_indices| are used to index the RHS vector which is always float.
116
+ int64_t diff = sizeof(RhsType);
117
+ if (i == 0)
118
+ diff *= block_width_ * (col_indices[i]);
119
+ else
120
+ diff *= block_width_ * (col_indices[i] - col_indices[i - 1]);
121
+
122
+ CHECK(diff < std::numeric_limits<DeltaType>::max())
123
+ << "delta between column indices in bytes " << diff
124
+ << " exceeded the maximum size of the DeltaType "
125
+ << std::numeric_limits<DeltaType>::max();
126
+ col_deltas.push_back(static_cast<DeltaType>(diff));
127
+ }
128
+
129
+ // Because of pre-fetching we need some extra values at the end.
130
+ col_deltas.insert(col_deltas.end(), std::max(2, col_multiple_ + 1), 0);
131
+ nnz_per_row.insert(nnz_per_row.end(), 2, nnz_per_row.back());
132
+
133
+ weights_ = CacheAlignedVector<WeightType>(weights_csr);
134
+ col_deltas_ = CacheAlignedVector<DeltaType>(col_deltas);
135
+ nnz_per_row_ = CacheAlignedVector<int>(nnz_per_row);
136
+ ComputeRHSIndices();
137
+
138
+ num_threads_ = 0;
139
+ PrepareForThreads(1);
140
+ }
141
+
142
+ // Constructor makes a matrix from the given weights, deltas and nnz, taking
143
+ // the other parameters from |src_matrix|. |cols| is the number of raw columns
144
+ // (NOT blocks) of the new matrix.
145
+ CsrBlockSparseMatrix(
146
+ const CsrBlockSparseMatrix<WeightType, RhsType, DeltaType>& src_matrix,
147
+ const std::vector<WeightType>& new_weights,
148
+ const std::vector<DeltaType>& new_deltas, const std::vector<int>& new_nnz,
149
+ int cols) {
150
+ num_threads_ = 0;
151
+ col_multiple_ = src_matrix.col_multiple_;
152
+ block_width_ = src_matrix.block_width_;
153
+ block_height_ = src_matrix.block_height_;
154
+ reduced_rows_ = new_nnz.size();
155
+ rows_ = reduced_rows_ * block_height_;
156
+ cols_ = cols;
157
+ reduced_cols_ = cols_ / block_width_;
158
+ weights_ = CacheAlignedVector<WeightType>(new_weights);
159
+ col_deltas_ = CacheAlignedVector<DeltaType>(new_deltas);
160
+ nnz_per_row_ = CacheAlignedVector<int>(new_nnz);
161
+ sparsity_ = 1.0f - static_cast<float>(new_weights.size()) / (rows_ * cols_);
162
+ ComputeRHSIndices();
163
+ name_ = src_matrix.name_;
164
+ PrepareForThreads(1);
165
+ }
166
+
167
+ // Factory method takes a column slice out of *this and returns a sparse
168
+ // matrix that takes as inputs [|start_col|, |end_col|) of *this, and
169
+ // returns the same number of outputs, but only a partial result.
170
+ // If |keep_rhs_size|, then the new matrix takes the same rhs as the current
171
+ // matrix, but uses a subset of it, instead of expecting just the reduced rhs.
172
+ // If |start_col| > |end_col|, then we slice out the complement of the defined
173
+ // interval, ie [0, |end_col|) + [|start_col|, current end).
174
+ // NOTE that |start_col| and |end_col| are in raw column coordinates, NOT
175
+ // block units.
176
+ CsrBlockSparseMatrix SplitByColumn(int start_col, int end_col,
177
+ bool keep_rhs_size = false) const {
178
+ int weight_index = 0;
179
+ int delta_index = 0;
180
+ std::vector<DeltaType> new_deltas;
181
+ std::vector<WeightType> new_weights;
182
+ std::vector<int> new_nnz(reduced_rows_);
183
+ int col = 0;
184
+ int prev_col = keep_rhs_size ? 0 : start_col;
185
+ for (int r = 0; r < reduced_rows_; ++r) {
186
+ int reduced_col_count = nnz_per_row_[r];
187
+ for (int c = 0; c < reduced_col_count; ++c, ++delta_index) {
188
+ col += col_deltas_[delta_index] / sizeof(RhsType);
189
+ if ((start_col < end_col && start_col <= col && col < end_col) ||
190
+ (start_col > end_col && (col < end_col || col >= start_col))) {
191
+ ++new_nnz[r];
192
+ new_deltas.push_back((col - prev_col) * sizeof(RhsType));
193
+ prev_col = col;
194
+ for (int i = 0; i < block_width_ * block_height_;
195
+ ++i, ++weight_index) {
196
+ new_weights.push_back(weights_[weight_index]);
197
+ }
198
+ } else {
199
+ weight_index += block_width_ * block_height_;
200
+ }
201
+ }
202
+ }
203
+ int new_cols = keep_rhs_size ? cols_ : end_col - start_col;
204
+ return CsrBlockSparseMatrix(*this, new_weights, new_deltas, new_nnz,
205
+ new_cols);
206
+ }
207
+
208
+ // Factory method takes a row slice out of *this and returns a sparse
209
+ // matrix that takes the same inputs as *this, and returns the outputs for
210
+ // the range [|start_row|, |end_row|).
211
+ // NOTE that |start_row| and |end_row| are in raw row coordinates, NOT
212
+ // block units.
213
+ CsrBlockSparseMatrix SplitByRow(int start_row, int end_row) const {
214
+ int start_reduced = start_row / block_height_;
215
+ int end_reduced = end_row / block_height_;
216
+ std::vector<int> new_nnz(nnz_per_row_.data() + start_reduced,
217
+ nnz_per_row_.data() + end_reduced);
218
+ int weight_start = 0;
219
+ for (int r = 0; r < start_reduced; ++r) {
220
+ weight_start += nnz_per_row_[r];
221
+ }
222
+ int weight_end = weight_start;
223
+ for (int r = start_reduced; r < end_reduced; ++r) {
224
+ weight_end += nnz_per_row_[r];
225
+ }
226
+ int delta_start = 0;
227
+ for (int i = 0; i < weight_start; ++i) {
228
+ delta_start += col_deltas_[i];
229
+ }
230
+ std::vector<DeltaType> new_deltas(col_deltas_.data() + weight_start,
231
+ col_deltas_.data() + weight_end);
232
+ new_deltas[0] += delta_start;
233
+ int block_size = block_height_ * block_width_;
234
+ std::vector<WeightType> new_weights(
235
+ weights_.data() + weight_start * block_size,
236
+ weights_.data() + weight_end * block_size);
237
+ return CsrBlockSparseMatrix(*this, new_weights, new_deltas, new_nnz, cols_);
238
+ }
239
+
240
+ // Combines adjacent row blocks, doubling the block height.
241
+ // This necessarily involves adding zero weights where the blocks don't align
242
+ // across adjacent pairs of rows, so use with caution, as the resulting matrix
243
+ // is most likely to run slower if very sparse to begin with.
244
+ // In the few cases where the blocks do mostly align, the resulting matmul
245
+ // could be much faster, as the number of reads of the rhs will be halved.
246
+ void DoubleBlockHeight() {
247
+ int new_rows = reduced_rows_ / 2;
248
+ std::vector<int> new_nnz(new_rows);
249
+ std::vector<DeltaType> new_rhs_indices;
250
+ std::vector<WeightType> new_weights;
251
+ int rhs_index1 = 0;
252
+ int rhs_index2 = 0;
253
+ int block_size = block_height_ * block_width_;
254
+ for (int r = 0; r < new_rows; ++r) {
255
+ int start_nnz = new_rhs_indices.size();
256
+ rhs_index2 += nnz_per_row_[r * 2];
257
+ int end1 = rhs_index1 + nnz_per_row_[r * 2];
258
+ int end2 = rhs_index2 + nnz_per_row_[r * 2 + 1];
259
+ // Run over a pair of rows with 2 iterators, combining blocks as we go, or
260
+ // padding with zeros where the block positions don't match.
261
+ while (rhs_index1 < end1 || rhs_index2 < end2) {
262
+ int col1 = rhs_index1 < end1 ? rhs_indices_[rhs_index1] : reduced_cols_;
263
+ int col2 = rhs_index2 < end2 ? rhs_indices_[rhs_index2] : reduced_cols_;
264
+ if (col1 < col2) {
265
+ // Need zero weights for row2 to pad out weights block.
266
+ new_rhs_indices.push_back(col1);
267
+ new_weights.insert(new_weights.end(),
268
+ weights_.data() + rhs_index1 * block_size,
269
+ weights_.data() + (rhs_index1 + 1) * block_size);
270
+ new_weights.insert(new_weights.end(), block_size,
271
+ static_cast<WeightType>(0.0f));
272
+ ++rhs_index1;
273
+ } else if (col1 > col2) {
274
+ // Need zero weights for row1 to pad out weights block.
275
+ new_rhs_indices.push_back(col2);
276
+ new_weights.insert(new_weights.end(), block_size,
277
+ static_cast<WeightType>(0.0f));
278
+ new_weights.insert(new_weights.end(),
279
+ weights_.data() + rhs_index2 * block_size,
280
+ weights_.data() + (rhs_index2 + 1) * block_size);
281
+ ++rhs_index2;
282
+ } else {
283
+ // Combine weights for both row1 and row2.
284
+ new_rhs_indices.push_back(col1);
285
+ new_weights.insert(new_weights.end(),
286
+ weights_.data() + rhs_index1 * block_size,
287
+ weights_.data() + (rhs_index1 + 1) * block_size);
288
+ new_weights.insert(new_weights.end(),
289
+ weights_.data() + rhs_index2 * block_size,
290
+ weights_.data() + (rhs_index2 + 1) * block_size);
291
+ ++rhs_index1;
292
+ ++rhs_index2;
293
+ }
294
+ }
295
+ rhs_index1 = rhs_index2;
296
+ new_nnz[r] = new_rhs_indices.size() - start_nnz;
297
+ }
298
+ block_height_ *= 2;
299
+ reduced_rows_ /= 2;
300
+ weights_ = CacheAlignedVector<WeightType>(new_weights);
301
+ rhs_indices_ = CacheAlignedVector<DeltaType>(new_rhs_indices);
302
+ nnz_per_row_ = CacheAlignedVector<int>(new_nnz);
303
+ sparsity_ = 1.0f - static_cast<float>(new_weights.size()) / (rows_ * cols_);
304
+ ComputeColDeltas();
305
+ if (num_threads_ > 0) {
306
+ int num_threads = num_threads_;
307
+ num_threads_ = 0;
308
+ PrepareForThreads(num_threads);
309
+ }
310
+ }
311
+
312
+ // Serializes the matrix into |csr_flatbuffer| and returns the number of
313
+ // bytes written; the internal scratch buffer is freed before returning.
314
+ // TODO(b/189958858): Both Read and Write need to eventually handle the
315
+ // different possible HalfType and DeltaType values, but punting for now as
316
+ // there is only one supported combination.
317
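+ //
+ // Serialization round-trip sketch, mirroring the FlatBufferSerialization unit
+ // test (|matrix| is a placeholder for an already-constructed instance):
+ //
+ //   std::string buffer;
+ //   std::size_t num_bytes = matrix.WriteToFlatBuffer(&buffer);
+ //   CsrBlockSparseMatrix<csrblocksparse::bfloat16, float, int16_t> restored(
+ //       reinterpret_cast<const uint8_t*>(buffer.c_str()), num_bytes);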
+ std::size_t WriteToFlatBuffer(std::string* csr_flatbuffer) {
318
+ std::size_t bytes = 0;
319
+ bytes += FixedParameterSize();
320
+ bytes += weights_.size() * sizeof(WeightType);
321
+ bytes += col_deltas_.size() * sizeof(DeltaType);
322
+ bytes += nnz_per_row_.size() * sizeof(int);
323
+
324
+ uint8_t* bytes_ptr_ptr =
325
+ reinterpret_cast<uint8_t*>(CHECK_NOTNULL(malloc(bytes)));
326
+
327
+ int* int_bytes_ptr = reinterpret_cast<int*>(bytes_ptr_ptr);
328
+
329
+ *int_bytes_ptr++ = rows_;
330
+ *int_bytes_ptr++ = cols_;
331
+ *int_bytes_ptr++ = reduced_rows_;
332
+ *int_bytes_ptr++ = reduced_cols_;
333
+ *int_bytes_ptr++ = block_width_;
334
+ *int_bytes_ptr++ = block_height_;
335
+ *int_bytes_ptr++ = col_multiple_;
336
+ *int_bytes_ptr++ = num_threads_;
337
+ *int_bytes_ptr++ = weights_.size();
338
+ *int_bytes_ptr++ = col_deltas_.size();
339
+ *int_bytes_ptr++ = nnz_per_row_.size();
340
+
341
+ float* float_bytes_ptr = reinterpret_cast<float*>(int_bytes_ptr);
342
+ *float_bytes_ptr++ = sparsity_;
343
+
344
+ uint8_t* bytes_ptr = reinterpret_cast<uint8_t*>(float_bytes_ptr);
345
+
346
+ memcpy(bytes_ptr, weights_.data(), weights_.size() * sizeof(WeightType));
347
+ bytes_ptr += weights_.size() * sizeof(WeightType);
348
+
349
+ memcpy(bytes_ptr, col_deltas_.data(),
350
+ col_deltas_.size() * sizeof(DeltaType));
351
+ bytes_ptr += col_deltas_.size() * sizeof(DeltaType);
352
+
353
+ memcpy(bytes_ptr, nnz_per_row_.data(), nnz_per_row_.size() * sizeof(int));
354
+ bytes_ptr += nnz_per_row_.size() * sizeof(int);
355
+
356
+ csr_flatbuffer->resize(bytes);
357
+ csr_flatbuffer->assign(reinterpret_cast<char*>(bytes_ptr_ptr), bytes);
358
+ free(bytes_ptr_ptr);
359
+
360
+ return bytes;
361
+ }
362
+
363
+ void ReadFromFlatBuffer(const uint8_t* const& bytes, const std::size_t& len) {
364
+ CHECK_GE(len, FixedParameterSize());
365
+
366
+ const int* int_bytes_ptr = reinterpret_cast<const int*>(bytes);
367
+ rows_ = *int_bytes_ptr++;
368
+ cols_ = *int_bytes_ptr++;
369
+ reduced_rows_ = *int_bytes_ptr++;
370
+ reduced_cols_ = *int_bytes_ptr++;
371
+ block_width_ = *int_bytes_ptr++;
372
+ block_height_ = *int_bytes_ptr++;
373
+ col_multiple_ = *int_bytes_ptr++;
374
+ int num_threads = *int_bytes_ptr++;
375
+ int32_t weights_size = *int_bytes_ptr++;
376
+ int32_t col_deltas_size = *int_bytes_ptr++;
377
+ int32_t nnz_per_row_size = *int_bytes_ptr++;
378
+
379
+ // Make sure negative sizes don't mess things up.
380
+ weights_size = std::max(0, weights_size);
381
+ col_deltas_size = std::max(0, col_deltas_size);
382
+ nnz_per_row_size = std::max(0, nnz_per_row_size);
383
+
384
+ const float* float_bytes_ptr =
385
+ reinterpret_cast<const float*>(int_bytes_ptr);
386
+ sparsity_ = *float_bytes_ptr++;
387
+
388
+ std::size_t total_bytes =
389
+ FixedParameterSize() + weights_size * sizeof(WeightType) +
390
+ col_deltas_size * sizeof(DeltaType) + nnz_per_row_size * sizeof(int);
391
+
392
+ CHECK_EQ(total_bytes, len)
393
+ << "total bytes: " << total_bytes << ", actual len given: " << len;
394
+
395
+ const uint8_t* bytes_ptr =
396
+ reinterpret_cast<const uint8_t*>(float_bytes_ptr);
397
+ std::vector<WeightType> weights_raw(weights_size);
398
+ memcpy(weights_raw.data(), bytes_ptr, weights_size * sizeof(WeightType));
399
+ weights_ = CacheAlignedVector<WeightType>(weights_raw);
400
+ bytes_ptr += weights_size * sizeof(WeightType);
401
+
402
+ std::vector<DeltaType> deltas_raw(col_deltas_size);
403
+ memcpy(deltas_raw.data(), bytes_ptr, col_deltas_size * sizeof(DeltaType));
404
+ col_deltas_ = CacheAlignedVector<DeltaType>(deltas_raw);
405
+ bytes_ptr += col_deltas_size * sizeof(DeltaType);
406
+
407
+ std::vector<int> nnz_raw(nnz_per_row_size);
408
+ memcpy(nnz_raw.data(), bytes_ptr, nnz_per_row_size * sizeof(int));
409
+ nnz_per_row_ = CacheAlignedVector<int>(nnz_raw);
410
+ num_threads_ = 0;
411
+ PrepareForThreads(num_threads);
412
+ }
413
+
414
+ // Multiplies a sparse matrix by a possibly dense matrix. Often the right-hand
415
+ // side has only a small number of columns, hence the term "fat vector".
416
+ // 1x1 and 4x4 have specializations for output columns (i.e. fatness) >= 5,
417
+ // and often achieve twice as many GFlops when multiplying a right hand side
418
+ // that has 5 or more columns. (Best is a multiple of 5).
419
+ // 16x1 doesn't have enough registers and just loops over the width 1 kernel.
420
+ //
421
+ // |rhs| and |out| are COLUMN MAJOR.
422
+
423
+ // Fast Tuples WeightType, BiasType, RhsType, OutType are:
424
+ // (float, float, float, float)
425
+ // (bfloat16, float, float, float)
426
+ // and only on ARM64. All other cases use a slow generic implementation.
427
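+ //
+ // Fat-vector sketch, mirroring the unit tests (|matrix|, |rows| and |cols|
+ // are placeholders; a width of 5 or more reaches the SpMM5 kernels):
+ //
+ //   constexpr int kFatness = 5;
+ //   FatCacheAlignedVector<float> fat_rhs(cols, kFatness);
+ //   FatCacheAlignedVector<float> fat_out(rows, kFatness);
+ //   CacheAlignedVector<float> bias(rows);
+ //   fat_rhs.FillRandom();
+ //   bias.FillRandom();
+ //   fat_out.FillZero();
+ //   matrix.SpMM_bias(fat_rhs, bias, &fat_out, /*relu=*/true);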
+ template <typename RhsClass, typename BiasClass, typename OutClass,
428
+ typename BiasType = typename BiasClass::value_type,
429
+ typename OutType = typename OutClass::value_type>
430
+ void SpMM_bias(const RhsClass& rhs, const BiasClass& bias, OutClass* out,
431
+ bool relu = false, int tid = 0,
432
+ SpinBarrier* barrier = nullptr) const {
433
+ static_assert(std::is_same<typename RhsClass::value_type, RhsType>::value,
434
+ "Rhs types must match");
435
+ CHECK_LT(tid, num_threads_);
436
+ CHECK_EQ(rhs.cols(), out->cols());
437
+ CHECK_EQ(rhs.rows(), cols_);
438
+ CHECK_GE(out->rows(), rows_);
439
+ int cols_to_go = out->cols();
440
+ int rhs_index = *thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid);
441
+ const RhsType* rhs_ptr = rhs.data() + rhs_index * block_height_;
442
+ OutType* out_ptr = thread_bounds_.OffsetOutput(out->data(), tid);
443
+ const WeightType* weights_ptr =
444
+ thread_bounds_.OffsetWeights(weights_.data(), tid);
445
+ const DeltaType* delta_ptr =
446
+ thread_bounds_.OffsetRhsIndices(col_deltas_.data(), tid);
447
+ int offset = *delta_ptr / sizeof(RhsType);
448
+ rhs_ptr -= offset;
449
+ const int* nnz_ptr = nnz_per_row_.data() + thread_bounds_.StartRow(tid);
450
+ int assigned_rows =
451
+ thread_bounds_.StartRow(tid + 1) - thread_bounds_.StartRow(tid);
452
+ const BiasType* bias_ptr = thread_bounds_.OffsetBias(bias.data(), tid);
453
+
454
+ while (cols_to_go > 0) {
455
+ if (block_width_ == 4 && block_height_ == 4) {
456
+ if (cols_to_go >= 5) {
457
+ detail::SpMM5_4x4<WeightType, RhsType, OutType>(
458
+ weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
459
+ assigned_rows, out->col_stride(), rhs.col_stride(), relu);
460
+ } else {
461
+ detail::SpMV_4x4<WeightType, RhsType, OutType>(
462
+ weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
463
+ assigned_rows, out->col_stride(), rhs.col_stride(), relu);
464
+ }
465
+ } else {
466
+ if (cols_to_go >= 5) {
467
+ detail::SpMM5_1x1<WeightType, RhsType, OutType>(
468
+ weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
469
+ assigned_rows, out->col_stride(), rhs.col_stride(), relu);
470
+ } else {
471
+ detail::SpMV_1x1<WeightType, RhsType, OutType>(
472
+ weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
473
+ assigned_rows, out->col_stride(), rhs.col_stride(), relu);
474
+ }
475
+ }
476
+
477
+ if (cols_to_go >= 5) {
478
+ cols_to_go -= 5;
479
+ rhs_ptr += rhs.col_stride() * 5;
480
+ out_ptr += out->col_stride() * 5;
481
+ } else {
482
+ cols_to_go--;
483
+ rhs_ptr += rhs.col_stride();
484
+ out_ptr += out->col_stride();
485
+ }
486
+ if (barrier) barrier->barrier();
487
+ }
488
+ }
489
+ template <typename MVRhsType, typename MVBiasType, typename OutType>
490
+ void MatVec(const MVRhsType* rhs, const MVBiasType* bias, bool relu, int tid,
491
+ int replicas, int output_stride, OutType* output) {
492
+ CHECK_LT(tid, num_threads_);
493
+ CHECK_EQ(block_width_, 4) << "Block width must be 4!";
494
+ if (block_height_ == 8) {
495
+ matmul_.MatVec8x4(
496
+ thread_bounds_.OffsetWeights(weights_.cast_data(), tid), rhs,
497
+ thread_bounds_.OffsetBias(bias, tid), nnz_per_row_.data(),
498
+ thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid),
499
+ thread_bounds_.StartRow(tid), thread_bounds_.StartRow(tid + 1), relu,
500
+ replicas, output_stride, thread_bounds_.OffsetOutput(output, tid));
501
+ } else {
502
+ CHECK_EQ(block_height_, 4) << "Block height must be 4 or 8!";
503
+ matmul_.MatVec4x4(
504
+ thread_bounds_.OffsetWeights(weights_.cast_data(), tid), rhs,
505
+ thread_bounds_.OffsetBias(bias, tid), nnz_per_row_.data(),
506
+ thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid),
507
+ thread_bounds_.StartRow(tid), thread_bounds_.StartRow(tid + 1), relu,
508
+ replicas, output_stride, thread_bounds_.OffsetOutput(output, tid));
509
+ }
510
+ }
511
+
512
+ int rows() const { return rows_; }
513
+ int cols() const { return cols_; }
514
+ int block_height() const { return block_height_; }
515
+ int block_width() const { return block_width_; }
516
+ float sparsity() const { return sparsity_; }
517
+ int num_threads() const { return num_threads_; }
518
+ const ThreadBounds& thread_bounds() const { return thread_bounds_; }
519
+ const CacheAlignedVector<DeltaType>& rhs_indices() const {
520
+ return rhs_indices_;
521
+ }
522
+ const std::string& name() const { return name_; }
523
+ void set_name(const std::string& name) { name_ = name; }
524
+ const std::vector<int>& split_points() const {
525
+ return thread_bounds_.row_starts();
526
+ }
527
+
528
+ std::size_t bytes() const {
529
+ return weights_.size() * sizeof(WeightType) +
530
+ col_deltas_.size() * sizeof(DeltaType) +
531
+ nnz_per_row_.size() * sizeof(int);
532
+ }
533
+
534
+ // Multiplies a sparse matrix by a possibly dense matrix, as SpMM_bias above,
535
+ // and then samples from the output (softmax distribution) layer.
536
+ template <typename RhsClass, typename BiasClass, typename OutClass,
537
+ typename BiasType = typename BiasClass::value_type,
538
+ typename OutType = typename OutClass::value_type>
539
+ typename std::enable_if<!IsFixed32Type<OutType>::value, int>::type
540
+ SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out,
541
+ float temperature, int tid, SpinBarrier* barrier,
542
+ std::minstd_rand* gen,
543
+ CacheAlignedVector<float>* scratch) const {
544
+ SpMM_bias(rhs, bias, out, /*relu=*/false, tid, barrier);
545
+ return out->Sample(temperature, gen, scratch);
546
+ }
547
+ // Fixed32 version.
548
+ template <typename RhsClass, typename BiasClass, typename OutClass,
549
+ typename BiasType = typename BiasClass::value_type,
550
+ typename OutType = typename OutClass::value_type>
551
+ typename std::enable_if<IsFixed32Type<OutType>::value, int>::type
552
+ SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out,
553
+ float temperature, int tid, SpinBarrier* barrier,
554
+ std::minstd_rand* gen,
555
+ CacheAlignedVector<float>* scratch) const {
556
+ // We don't pass the barrier on, as we have more work to do.
557
+ SpMM_bias(rhs, bias, out, /*relu=*/false, tid);
558
+ return out->ReducingSample(gen, scratch, tid, temperature, barrier);
559
+ }
560
+
561
+ void Print() const {
562
+ std::cout << "Weights\n";
563
+ weights_.Print();
564
+ std::cout << std::endl;
565
+ std::cout << "Deltas\n";
566
+ col_deltas_.Print();
567
+ std::cout << std::endl;
568
+ std::cout << "nnz\n";
569
+ nnz_per_row_.Print();
570
+ std::cout << std::endl;
571
+ }
572
+
573
+ // Splits the computation amongst threads by rows, based on the number of
574
+ // non zeros, with the addition of a constant to account for the work of the
575
+ // bias and the horizontal add at the end, and also guarantees that each
576
+ // thread writes only whole cache lines, based on the size of OutType.
577
+ // The |cache_line_size| arg is used only for testing. Normally it is provided
578
+ // through the architecture #defines.
579
+ // Each thread gets a contiguous row range (|split_points|).
580
+ // Thread t does rows [ split_points[t], split_points[t + 1] )
581
+ // Each thread also needs to know how many non zeros were before it to skip
582
+ // (|nnz_to_skip|). And finally it also needs to know what the offset into
583
+ // the rhs vector would have been at the split point (|rhs_to_skip|).
584
+ //
585
+ // Some tricky corner cases where the number of non-zeros doesn't split
586
+ // nicely amongst the number of requested threads are not handled and default
587
+ // to one thread; these cases are only going to happen in tests and not in
588
+ // the matrices that correspond in real models.
589
+ //
590
+ // Returns the maximum number of threads that can be used; <= |num_threads|.
591
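+ //
+ // Threaded-use sketch, mirroring the multi-threaded unit tests (|matrix|,
+ // |rhs|, |bias|, |out| and |desired_threads| are placeholders):
+ //
+ //   int used_threads = matrix.PrepareForThreads(desired_threads);
+ //   // On each worker thread, with tid in [0, used_threads):
+ //   matrix.SpMM_bias(rhs, bias, &out, /*relu=*/false, tid);
+ //   // Threads must synchronize (e.g. via the SpinBarrier argument) before
+ //   // any of them reads |out| as the input to a later multiply.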
+ template <typename OutType = int32_t>
592
+ int PrepareForThreads(int num_threads, int cache_line_size = -1) {
593
+ CHECK_GT(num_threads, 0);
594
+ // We've already prepared for this number of threads; nothing to do.
595
+ if (num_threads == num_threads_) return num_threads_;
596
+
597
+ num_threads_ = num_threads;
598
+ thread_bounds_.PrepareForThreads(
599
+ block_width_, block_height_, num_threads_,
600
+ ReducedRowsPerCacheLine<OutType>(cache_line_size), reduced_rows_,
601
+ nnz_per_row_.data());
602
+ return num_threads_;
603
+ }
604
+
605
+ // Computes and stores the |rhs_indices_| from the |col_deltas_|.
606
+ void ComputeRHSIndices() {
607
+ std::vector<int> cumulative_deltas = CumulativeColDeltas();
608
+ std::vector<DeltaType> rhs_indices(cumulative_deltas.size() +
609
+ reduced_rows_);
610
+ int total_indices = 0;
611
+ int delta_index = 0;
612
+ for (int r = 0; r < reduced_rows_; ++r) {
613
+ for (int n = 0; n < nnz_per_row_[r]; ++n, ++delta_index) {
614
+ rhs_indices[total_indices++] =
615
+ cumulative_deltas[delta_index] / block_width_;
616
+ }
617
+ }
618
+ rhs_indices_ = CacheAlignedVector<DeltaType>(rhs_indices);
619
+ }
620
+
621
+ // Computes and stores the |col_deltas_| from the |rhs_indices_|.
622
+ void ComputeColDeltas() {
623
+ std::vector<int> col_deltas(rhs_indices_.size());
624
+ int prev_index = 0;
625
+ for (int i = 0; i < rhs_indices_.size(); ++i) {
626
+ int offset = rhs_indices_[i] - prev_index;
627
+ prev_index = rhs_indices_[i];
628
+ col_deltas[i] = offset * block_width_ * sizeof(RhsType);
629
+ }
630
+ col_deltas_ = CacheAlignedVector<DeltaType>(col_deltas);
631
+ }
632
+
633
+ // Computes and returns the inclusive prefix sum of the deltas, ie absolute
634
+ // positions.
635
+ std::vector<int> CumulativeColDeltas() const {
636
+ std::vector<int> cum_col_deltas(col_deltas_.size());
637
+ for (int i = 0; i < col_deltas_.size(); ++i) {
638
+ cum_col_deltas[i] = col_deltas_[i] / sizeof(RhsType);
639
+ if (i > 0) cum_col_deltas[i] += cum_col_deltas[i - 1];
640
+ }
641
+ return cum_col_deltas;
642
+ }
643
+
644
+ private:
645
+ constexpr std::size_t FixedParameterSize() const {
646
+ return sizeof(int) // rows
647
+ + sizeof(int) // cols
648
+ + sizeof(int) // reduced_rows
649
+ + sizeof(int) // reduced_cols
650
+ + sizeof(int) // block_width
651
+ + sizeof(int) // block_height
652
+ + sizeof(float) // sparsity
653
+ + sizeof(int) // col_multiple
654
+ + sizeof(int) // num_threads_
655
+ + sizeof(int) // weights_.size()
656
+ + sizeof(int) // col_deltas_.size()
657
+ + sizeof(int); // nnz_per_row_.size()
658
+ }
659
+ // Possible block sizes are only those supported by the computation; the
660
+ // default is 1x1, and the other options are 4x4 and 16x1.
661
+ template <typename InputType>
662
+ void DetermineBlockSize(const MaskedSparseMatrix<InputType>& masked_matrix) {
663
+ const std::vector<std::pair<int, int>> kPreferredOrder = {{4, 4}};
664
+ int rows = masked_matrix.rows();
665
+ int cols = masked_matrix.cols();
666
+
667
+ for (const auto& block_size : kPreferredOrder) {
668
+ int block_height, block_width;
669
+ std::tie(block_height, block_width) = block_size;
670
+ if (cols % block_width != 0) continue;
671
+
672
+ int reduced_rows = (rows + block_height - 1) / block_height;
673
+ int reduced_cols = cols / block_width;
674
+
675
+ // For each possible block, confirm that it is either all 0s or all 1s.
676
+ bool all_same = true;
677
+ const auto& mask = masked_matrix.mask();
678
+ for (int r = 0; r < reduced_rows; ++r) {
679
+ for (int c = 0; c < reduced_cols; ++c) {
680
+ int val = mask[r * block_height * cols + c * block_width];
681
+ for (int i = 0; i < block_height; ++i) {
682
+ for (int j = 0; j < block_width; ++j) {
683
+ int index = (r * block_height + i) * cols + c * block_width + j;
684
+ if (index < masked_matrix.mask().size()) {
685
+ all_same &= (masked_matrix.mask()[index] == val);
686
+ }
687
+ }
688
+ }
689
+ }
690
+ }
691
+
692
+ // If this block configuration is possible, accept it.
693
+ if (all_same) {
694
+ block_height_ = block_height;
695
+ block_width_ = block_width;
696
+ return;
697
+ }
698
+ }
699
+
700
+ // No large blocks were found, default to 1x1.
701
+ block_height_ = 1;
702
+ block_width_ = 1;
703
+ }
704
+
705
+ // CSR descriptors are for the reduced matrix, weights is the full matrix.
706
+ template <typename InputType>
707
+ void MakeColumnsMultiple(const std::vector<int>& row_offsets,
708
+ std::vector<int>* reduced_mask,
709
+ std::vector<InputType>* weights) {
710
+ if (col_multiple_ > 0) {
711
+ // Make sure each row has a number of columns that is a multiple of
712
+ // |col_multiple|.
713
+ for (int r = 1; r < row_offsets.size(); ++r) {
714
+ int num_row = row_offsets[r] - row_offsets[r - 1];
715
+ int num_needed = col_multiple_ - num_row % col_multiple_;
716
+ if (num_needed < col_multiple_) {
717
+ // Find gaps in the columns where we can insert a column of 0 weights.
718
+ int num_added = 0;
719
+ for (int c = 0; c < reduced_cols_; ++c) {
720
+ if ((*reduced_mask)[(r - 1) * reduced_cols_ + c] == 0) {
721
+ (*reduced_mask)[(r - 1) * reduced_cols_ + c] = 1;
722
+
723
+ // Zero out the weights that correspond to this block.
724
+ for (int i = 0; i < block_height_; ++i) {
725
+ for (int j = 0; j < block_width_; ++j) {
726
+ (*weights)[((r - 1) * block_height_ + i) * cols_ +
727
+ block_width_ * c + j] = InputType(0.f);
728
+ }
729
+ }
730
+ num_added++;
731
+ }
732
+
733
+ if (num_added == num_needed) break;
734
+ }
735
+ }
736
+ }
737
+ }
738
+ }
739
+
740
+ // Given the final dense mask and weights, convert to the compressed
741
+ // block CSR representation.
742
+ template <typename InputType>
743
+ void MaskAndWeightsToCsr(const std::vector<int>& mask,
744
+ const std::vector<InputType>& weights,
745
+ std::vector<int>* nnz_per_row,
746
+ std::vector<int>* col_indices,
747
+ std::vector<WeightType>* weights_csr) {
748
+ std::vector<int> row_offsets = {0};
749
+ int nnz = 0;
750
+ // Standard CSR format.
751
+ if (block_width_ == 1 && block_height_ == 1) {
752
+ for (int r = 0; r < rows_; ++r) {
753
+ for (int c = 0; c < cols_; ++c) {
754
+ if (mask[r * cols_ + c] == 1) {
755
+ nnz++;
756
+ col_indices->push_back(c);
757
+ weights_csr->push_back(WeightType(weights[r * cols_ + c]));
758
+ }
759
+ }
760
+ row_offsets.push_back(nnz);
761
+ }
762
+ } else if (block_width_ == 4 && block_height_ == 4) {
763
+ // Weights are stored contiguously for each block in this case.
764
+ for (int r = 0; r < reduced_rows_; ++r) {
765
+ for (int c = 0; c < reduced_cols_; ++c) {
766
+ if (mask[r * reduced_cols_ + c] == 1) {
767
+ col_indices->push_back(c);
768
+ nnz++;
769
+ for (int i = 0; i < block_height_; ++i) {
770
+ for (int j = 0; j < block_width_; ++j) {
771
+ int row_index = (block_height_ * r + i) * cols_;
772
+ int w_index = row_index + block_width_ * c + j;
773
+ WeightType weight = w_index < weights.size()
774
+ ? WeightType(weights[w_index])
775
+ : WeightType(0.0f);
776
+ weights_csr->push_back(weight);
777
+ }
778
+ }
779
+ }
780
+ }
781
+ row_offsets.push_back(nnz);
782
+ }
783
+ }
784
+ for (int i = 1; i < row_offsets.size(); ++i)
785
+ nnz_per_row->push_back(row_offsets[i] - row_offsets[i - 1]);
786
+ }
787
+
788
+ // Returns the number of block rows per cache line. This is the minimum unit
789
+ // into which the calculation is broken for threads.
790
+ template <typename OutType>
791
+ int ReducedRowsPerCacheLine(int override_cache_line_size = -1) const {
792
+ int line_size = kCacheLineSize;
793
+ if (override_cache_line_size >= 1) line_size = override_cache_line_size;
794
+ return std::max<int>(line_size / (block_height_ * sizeof(OutType)), 1);
795
+ }
796
+
797
+ int col_multiple_;
798
+ int rows_;
799
+ int cols_;
800
+ int reduced_rows_;
801
+ int reduced_cols_;
802
+ float sparsity_;
803
+ int block_width_;
804
+ int block_height_;
805
+ int num_threads_;
806
+ std::string name_;
807
+
808
+ CacheAlignedVector<WeightType> weights_;
809
+ CacheAlignedVector<DeltaType> col_deltas_;
810
+ CacheAlignedVector<int> nnz_per_row_;
811
+ // |thread_bounds_| and |rhs_indices_| don't need to be serialized as they are
812
+ // always recalculated from serialized data.
813
+ CacheAlignedVector<DeltaType> rhs_indices_;
814
+ Matmul<WeightType, RhsType> matmul_;
815
+ ThreadBounds thread_bounds_;
816
+ static constexpr int kCacheLineSize = 64;
817
+ };
818
+
819
+ // Converts a sparse matrix represented with (|mask|, |weights|, |rows|, |cols|) into
820
+ // the CSR format, and returns that as a serialized string.
821
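+ //
+ // Usage sketch (|mask| and |weights| are placeholders for row-major vectors
+ // of size rows * cols):
+ //
+ //   std::string serialized = ConvertDenseToSparseRepresentation_Int16Deltas(
+ //       mask, weights, rows, cols);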
+ template <typename MaskType>
822
+ std::string ConvertDenseToSparseRepresentation_Int16Deltas(
823
+ const std::vector<MaskType>& mask, const std::vector<float>& weights,
824
+ const int rows, const int cols) {
825
+ MaskedSparseMatrix<float> masked_weights(rows, cols, mask.data(),
826
+ weights.data());
827
+ CsrBlockSparseMatrix<csrblocksparse::bfloat16, float, int16_t>
828
+ sparse_masked_weights(masked_weights);
829
+ std::string buffer;
830
+ sparse_masked_weights.WriteToFlatBuffer(&buffer);
831
+ return buffer;
832
+ }
833
+
834
+ } // namespace csrblocksparse
835
+ #endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_
sparse_matmul/layers/csrblocksparse_test.cc ADDED
@@ -0,0 +1,977 @@
1
+ // Copyright 2021 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #include <array>
16
+ #include <cstdint>
17
+ #include <tuple>
18
+ #include <vector>
19
+
20
+ // Placeholder for get runfiles header.
21
+ #include "absl/status/status.h"
22
+ #include "absl/strings/str_cat.h"
23
+ #include "absl/strings/string_view.h"
24
+ #include "absl/types/span.h"
25
+ #include "gtest/gtest.h"
26
+ #include "include/ghc/filesystem.hpp"
27
+ #include "sparse_matmul/compute/matmul.h"
28
+ #include "sparse_matmul/layers/utils.h"
29
+ #include "sparse_matmul/numerics/test_utils.h"
30
+ #include "sparse_matmul/os/coop_threads.h"
31
+
32
+ namespace csrblocksparse {
33
+ namespace {
34
+
35
+ inline constexpr absl::string_view kTestdataPath = "layers/testdata";
36
+
37
+ TEST(CSRBlockSparseMatrix, FlatBufferSerialization) {
38
+ const int kRows = 8;
39
+ const int kCols = 8;
40
+ std::vector<int> mask = {1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
41
+ 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
42
+ 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1,
43
+ 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1};
44
+ std::vector<float> values(kRows * kCols, 1.f);
45
+ values[1] = 2.f;
46
+ values[3] = 3.f;
47
+ values[36] = -1.f;
48
+ values[45] = -2.f;
49
+
50
+ csrblocksparse::CacheAlignedVector<float> bias(kRows);
51
+ csrblocksparse::CacheAlignedVector<float> rhs(kCols);
52
+ csrblocksparse::CacheAlignedVector<float> out_ref(kRows);
53
+ csrblocksparse::CacheAlignedVector<float> out_test(kRows);
54
+
55
+ bias.FillZero();
56
+ rhs.FillOnes();
57
+
58
+ csrblocksparse::MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(),
59
+ values.data());
60
+
61
+ matrix.SpMM_bias(rhs, bias, &out_ref);
62
+
63
+ csrblocksparse::CsrBlockSparseMatrix<csrblocksparse::bfloat16, float, int16_t>
64
+ block_sparse_matrix(matrix);
65
+
66
+ std::string buffer;
67
+ std::size_t num_bytes = block_sparse_matrix.WriteToFlatBuffer(&buffer);
68
+
69
+ csrblocksparse::CsrBlockSparseMatrix<csrblocksparse::bfloat16, float, int16_t>
70
+ new_block_sparse_matrix(reinterpret_cast<const uint8_t*>(buffer.c_str()),
71
+ num_bytes);
72
+
73
+ new_block_sparse_matrix.SpMM_bias(rhs, bias, &out_test);
74
+
75
+ CheckResult(out_ref, out_test, kCols);
76
+ }
77
+
78
+ template <typename ComputeType, typename RhsType, typename OutType>
79
+ void CorrectnessCheckBlockSpMM(int rows, int cols, int block_height,
80
+ int block_width, float sparsity,
81
+ bool use_relu = false, int num_threads = 1,
82
+ int fatness = 1, bool test_matmul = false) {
83
+ using BiasType = typename TypeOfProduct<ComputeType, RhsType>::type;
84
+ MaskedSparseMatrix<float> matrix(rows, cols, sparsity, block_height,
85
+ block_width);
86
+ matrix.CastWeights<ComputeType>();
87
+ FatCacheAlignedVector<RhsType> rhs(cols, fatness);
88
+ CacheAlignedVector<BiasType> bias(rows);
89
+ FatCacheAlignedVector<OutType> out(rows, fatness);
90
+
91
+ bias.FillRandom();
92
+ rhs.FillRandom();
93
+ out.FillZero();
94
+ FatCacheAlignedVector<OutType> out_reference = out;
95
+
96
+ matrix.SpMM_bias(rhs, bias, &out_reference, use_relu);
97
+
98
+ CsrBlockSparseMatrix<ComputeType, RhsType> sparse_matrix(matrix);
99
+
100
+ SparseLinearLayer<ComputeType, RhsType> sparse_linear_layer(
101
+ std::move(sparse_matrix), std::move(bias));
102
+ num_threads = sparse_linear_layer.PrepareForThreads(num_threads);
103
+
104
+ // Checks that the result of applying each thread's portion serially is
105
+ // correct.
106
+ for (int thread_id = 0; thread_id < num_threads; ++thread_id) {
107
+ sparse_linear_layer.SpMM_bias(rhs, &out, use_relu, thread_id);
108
+ }
109
+
110
+ CheckResult(out_reference, out, sparse_linear_layer.cols());
111
+
112
+ if (test_matmul) {
113
+ for (int thread_id = 0; thread_id < num_threads; ++thread_id) {
114
+ sparse_linear_layer.MatVec(rhs, use_relu, thread_id,
115
+ /*replicas=*/1, /*output_stride=*/0, &out);
116
+ }
117
+
118
+ CheckResult(out_reference, out, sparse_linear_layer.cols());
119
+ }
120
+ }
121
+
122
+ // Does:
123
+ // y = Ax + b;
124
+ // x = Ay + b;
125
+ // y = Ax + b;
126
+ //
127
+ // to make sure that dependent multiplies are correct.
128
+ template <typename ComputeType, typename RhsType, typename OutType>
129
+ void ThreadBody(
130
+ SpinBarrier* spin_barrier, int tid,
131
+ const SparseLinearLayer<ComputeType, RhsType>& sparse_linear_layer,
132
+ FatCacheAlignedVector<RhsType>* rhs, FatCacheAlignedVector<OutType>* out,
133
+ bool use_relu) {
134
+ sparse_linear_layer.SpMM_bias(*rhs, out, use_relu, tid);
135
+ spin_barrier->barrier();
136
+ sparse_linear_layer.SpMM_bias(*out, rhs, use_relu, tid);
137
+ spin_barrier->barrier();
138
+ sparse_linear_layer.SpMM_bias(*rhs, out, use_relu, tid);
139
+ }
140
+
141
+ template <typename ComputeType, typename RhsType, typename OutType>
142
+ void CorrectnessCheckBlockSpMM_MultiThread(int rows, int cols, int block_height,
143
+ int block_width, float sparsity,
144
+ bool use_relu = false,
145
+ int num_threads = 1,
146
+ int fatness = 1) {
147
+ typedef typename TypeOfProduct<ComputeType, RhsType>::type BiasType;
148
+ CHECK(rows == cols);
149
+ MaskedSparseMatrix<float> matrix(rows, cols, sparsity, block_height,
150
+ block_width);
151
+ matrix.CastWeights<ComputeType>();
152
+ FatCacheAlignedVector<RhsType> rhs(cols, fatness);
153
+ FatCacheAlignedVector<RhsType> rhs_mt(cols, fatness);
154
+ CacheAlignedVector<BiasType> bias(rows);
155
+ FatCacheAlignedVector<OutType> out(rows, fatness);
156
+
157
+ bias.FillOnes();
158
+ rhs.FillOnes();
159
+ rhs_mt.FillOnes();
160
+ out.FillZero();
161
+ FatCacheAlignedVector<OutType> out_reference = out;
162
+
163
+ matrix.SpMM_bias(rhs, bias, &out_reference, use_relu);
164
+ matrix.SpMM_bias(out_reference, bias, &rhs, use_relu);
165
+ matrix.SpMM_bias(rhs, bias, &out_reference, use_relu);
166
+
167
+ CsrBlockSparseMatrix<ComputeType, RhsType> sparse_matrix(matrix);
168
+
169
+ num_threads = sparse_matrix.PrepareForThreads(num_threads,
170
+ /*cache_line_size=*/1);
171
+
172
+ SparseLinearLayer<ComputeType, RhsType> sparse_linear_layer(
173
+ std::move(sparse_matrix), std::move(bias));
174
+
175
+ csrblocksparse::LaunchOnThreadsWithBarrier(
176
+ num_threads, ThreadBody<ComputeType, RhsType, OutType>,
177
+ sparse_linear_layer, &rhs_mt, &out, use_relu);
178
+
179
+ CheckResult(out_reference, out, cols);
180
+ }
181
+
182
+ } // namespace
183
+
184
+ TEST(MaskedSparseCorrectness, HandCoded) {
185
+ const int kRows = 8;
186
+ const int kCols = 8;
187
+ // clang-format off
188
+ std::vector<int> mask = {1, 1, 0, 0, 0, 1, 1, 1,
189
+ 0, 1, 0, 1, 0, 1, 0, 1,
190
+ 1, 0, 0, 1, 1, 1, 1, 0,
191
+ 0, 0, 0, 0, 0, 0, 0, 0,
192
+ 1, 1, 1, 1, 1, 1, 1, 1,
193
+ 0, 0, 0, 0, 1, 1, 0, 0,
194
+ 1, 1, 0, 0, 1, 1, 0, 0,
195
+ 1, 0, 0, 0, 0, 1, 0, 1};
196
+ // clang-format on
197
+ std::vector<float> values(kRows * kCols, 1.f);
198
+
199
+ std::vector<float> answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f};
200
+
201
+ MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(), values.data());
202
+ CacheAlignedVector<float> rhs(kCols);
203
+ CacheAlignedVector<float> bias(kRows);
204
+ CacheAlignedVector<float> out(kRows);
205
+
206
+ bias.FillOnes();
207
+ rhs.FillOnes();
208
+ out.FillZero();
209
+
210
+ MaskedLinearLayer<float> masked_linear_layer(std::move(matrix),
211
+ std::move(bias));
212
+
213
+ masked_linear_layer.SpMM_bias(rhs, &out);
214
+
215
+ for (int i = 0; i < kRows; ++i) {
216
+ EXPECT_EQ(answer[i], out[i]);
217
+ }
218
+ }
219
+
220
+ TEST(MaskedSparseCorrectness, HandCodedFatVector) {
221
+ const int kRows = 8;
222
+ const int kCols = 8;
223
+ // clang-format off
224
+ std::vector<int> mask = {1, 1, 0, 0, 0, 1, 1, 1,
225
+ 0, 1, 0, 1, 0, 1, 0, 1,
226
+ 1, 0, 0, 1, 1, 1, 1, 0,
227
+ 0, 0, 0, 0, 0, 0, 0, 0,
228
+ 1, 1, 1, 1, 1, 1, 1, 1,
229
+ 0, 0, 0, 0, 1, 1, 0, 0,
230
+ 1, 1, 0, 0, 1, 1, 0, 0,
231
+ 1, 0, 0, 0, 0, 1, 0, 1};
232
+ // clang-format on
233
+
234
+ std::vector<float> values(kRows * kCols, 1.f);
235
+ std::vector<float> answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f};
236
+
237
+ MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(), values.data());
238
+ const int kMaxWidth = 5;
239
+ for (int width = 5; width <= kMaxWidth; ++width) {
240
+ FatCacheAlignedVector<float> rhs(kCols, width);
241
+ CacheAlignedVector<float> bias(kRows);
242
+ FatCacheAlignedVector<float> out(kRows, width);
243
+
244
+ bias.FillOnes();
245
+ rhs.FillOnes();
246
+ out.FillZero();
247
+
248
+ MaskedLinearLayer<float> masked_linear_layer(std::move(matrix),
249
+ std::move(bias));
250
+
251
+ masked_linear_layer.SpMM_bias(rhs, &out);
252
+
253
+ for (int i = 0; i < kRows; ++i) {
254
+ for (int width = 0; width < kMaxWidth; ++width) {
255
+ EXPECT_EQ(answer[i], out[i + width * kRows]);
256
+ }
257
+ }
258
+ }
259
+ }
260
+
261
+ TEST(CsrBlockSparseMatrix, HandCodedMultiThread) {
262
+ const int kRows = 8;
263
+ const int kCols = 8;
264
+ // clang-format off
265
+ std::vector<int> mask = {1, 1, 0, 0, 0, 1, 1, 1,
266
+ 0, 1, 0, 1, 0, 1, 0, 1,
267
+ 1, 0, 0, 1, 1, 1, 1, 0,
268
+ 0, 0, 0, 0, 0, 0, 0, 0,
269
+ 1, 1, 1, 1, 1, 1, 1, 1,
270
+ 0, 0, 0, 0, 1, 1, 0, 0,
271
+ 1, 1, 0, 0, 1, 1, 0, 0,
272
+ 1, 0, 0, 0, 0, 1, 0, 1};
273
+ // clang-format on
274
+ std::vector<float> values(kRows * kCols, 1.f);
275
+
276
+ std::vector<float> answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f};
277
+
278
+ MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(), values.data());
279
+ CacheAlignedVector<float> rhs(kCols);
280
+ CacheAlignedVector<float> bias(kRows);
281
+ CacheAlignedVector<float> out(kRows);
282
+
283
+ bias.FillOnes();
284
+ rhs.FillOnes();
285
+ out.FillZero();
286
+
287
+ CacheAlignedVector<float> bias_csr = bias;
288
+
289
+ CsrBlockSparseMatrix<bfloat16, float> sparse_matrix(matrix);
290
+
291
+ MaskedLinearLayer<float> masked_linear_layer(std::move(matrix),
292
+ std::move(bias));
293
+
294
+ masked_linear_layer.SpMM_bias(rhs, &out);
295
+
296
+ SparseLinearLayer<bfloat16, float> sparse_linear_layer(
297
+ std::move(sparse_matrix), std::move(bias_csr));
298
+ sparse_linear_layer.PrepareForThreads(2, /*cache_line_size=*/1);
299
+
300
+ CacheAlignedVector<float> out_tmp(kRows);
301
+ const bool kUseRelu = false;
302
+ sparse_linear_layer.SpMM_bias(rhs, &out_tmp, kUseRelu, /*tid=*/0);
303
+ sparse_linear_layer.SpMM_bias(rhs, &out_tmp, kUseRelu, /*tid=*/1);
304
+
305
+ for (int i = 0; i < kRows; ++i) {
306
+ EXPECT_EQ(answer[i], out_tmp[i]);
307
+ }
308
+ }
309
+
310
+ TEST(TestCasts, TestBfloat16) {
311
+ const int kRows = 1000;
312
+ const int kCols = 100;
313
+ const float kSparsity = 0.f;
314
+
315
+ MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);
316
+ MaskedSparseMatrix<float> matrix_bfloat16(kRows, kCols, matrix.mask().data(),
317
+ matrix.values().data());
318
+
319
+ matrix_bfloat16.CastWeights<bfloat16>();
320
+
321
+ CheckResult(matrix.values(), matrix_bfloat16.values(), kCols);
322
+ }
323
+
324
+ TEST(TestCasts, TestFP16) {
325
+ const int kRows = 1000;
326
+ const int kCols = 100;
327
+ const float kSparsity = 0.f;
328
+
329
+ MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);
330
+ #if !defined __arm__ && !defined __aarch64__
331
+ // Conversion doesn't handle denormals, so flush denormals to zero first.
332
+ for (int i = 0; i < matrix.values().size(); ++i) {
333
+ if (matrix.data()[i] < 1. / static_cast<float>(1 << 14))
334
+ matrix.data()[i] = 0.f;
335
+ }
336
+ #endif
337
+ MaskedSparseMatrix<float> matrix_fp16(kRows, kCols, matrix.mask().data(),
338
+ matrix.values().data());
339
+
340
+ matrix_fp16.CastWeights<csrblocksparse::fp16>();
341
+
342
+ CheckResult(matrix.values(), matrix_fp16.values(), kCols);
343
+ }
344
+
345
+ TEST(TestCasts, TestFixed16) {
346
+ const int kRows = 100000;
347
+ const int kCols = 1;
348
+ const float kSparsity = 0.f;
349
+
350
+ MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);
351
+
352
+ // Relative error for fixed point is high near 0.
353
+ for (int i = 0; i < matrix.values().size(); ++i) {
354
+ // 1.1e-3 is based on the max error of .013 and a grid spacing of 1 / 2**16
355
+ // == 3e-5. 3e-5 / .013 / 2 = 1.1e-3.
356
+ if (std::abs(matrix.data()[i]) < 1.1e-3) {
357
+ matrix.data()[i] = 0.f;
358
+ }
359
+ }
360
+
361
+ MaskedSparseMatrix<float> matrix_fixed16 = matrix;
362
+
363
+ matrix_fixed16.CastWeights<csrblocksparse::fixed16</*ExponentBits=*/0>>();
364
+
365
+ CheckResult(matrix.values(), matrix_fixed16.values(), kCols);
366
+ }
367
+
368
+ TEST(TestCasts, TestFixed32) {
369
+ const int kRows = 100000;
370
+ const int kCols = 1;
371
+ const float kSparsity = 0.f;
372
+
373
+ MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);
374
+ MaskedSparseMatrix<float> matrix_fixed32 = matrix;
375
+
376
+ matrix_fixed32.CastWeights<csrblocksparse::fixed32</*ExponentBits=*/0>>();
377
+
378
+ CheckResult(matrix.values(), matrix_fixed32.values(), kCols);
379
+ }
380
+
381
+ template <typename ComputeType, typename RhsType, typename OutType>
382
+ void TestSpMM(int block_width, int block_height, int fatness,
383
+ bool test_matmul = false) {
384
+ std::array<bool, 2> use_relu = {false, true};
385
+ std::vector<float> sparsity_levels = {.5, .8, .9, .95, .98};
386
+ std::vector<std::pair<int, int>> sizes = {{8, 8}, {128, 128}, {128, 64},
387
+ {256, 192}, {512, 512}, {1024, 512},
388
+ {384, 384}, {512, 384}};
389
+ for (int num_threads = 1; num_threads < 2 + test_matmul; ++num_threads) {
390
+ for (const auto& relu : use_relu) {
391
+ for (const auto& sparsity : sparsity_levels) {
392
+ for (const auto& size : sizes) {
393
+ int rows, cols;
394
+ std::tie(rows, cols) = size;
395
+ CorrectnessCheckBlockSpMM<ComputeType, RhsType, OutType>(
396
+ rows, cols, block_height, block_width, sparsity, relu,
397
+ num_threads, fatness, test_matmul);
398
+ }
399
+ }
400
+ }
401
+ }
402
+ }
403
+
404
+ template <typename ComputeType, typename RhsType, typename OutType>
405
+ void TestSpMM_MultiThread(int block_width, int block_height, int fatness) {
406
+ std::array<bool, 2> use_relu = {false, true};
407
+ std::vector<float> sparsity_levels = {.5, .8, .9, .95, .98};
408
+ std::vector<std::pair<int, int>> sizes = {
409
+ {48, 48}, {128, 128}, {512, 512}, {384, 384}};
410
+ for (int num_threads = 1; num_threads < 5; ++num_threads) {
411
+ for (const auto& relu : use_relu) {
412
+ for (const auto& sparsity : sparsity_levels) {
413
+ for (const auto& size : sizes) {
414
+ int rows, cols;
415
+ std::tie(rows, cols) = size;
416
+ CorrectnessCheckBlockSpMM_MultiThread<ComputeType, RhsType, OutType>(
417
+ rows, cols, block_height, block_width, sparsity, relu,
418
+ num_threads, fatness);
419
+ }
420
+ }
421
+ }
422
+ }
423
+ }
424
+
425
+ template <typename DataType>
426
+ void TestSumVectors(int start = 0, int end = -1, int size = 6) {
427
+ std::vector<DataType> values;
428
+ std::vector<DataType> answer;
429
+
430
+ for (int i = 1; i < size + 1; ++i) {
431
+ const float x = static_cast<float>(i);
432
+ values.push_back(static_cast<DataType>(x));
433
+ answer.push_back(static_cast<DataType>(x * 2));
434
+ }
435
+
436
+ if (end == -1) {
437
+ end = values.size();
438
+ }
439
+
440
+ csrblocksparse::CacheAlignedVector<DataType> result(values.size());
441
+ csrblocksparse::CacheAlignedVector<DataType> values_aligned(values);
442
+ detail::SumVectors(start, end, values_aligned.data(), values_aligned.data(),
443
+ result.data());
444
+ for (int i = start; i < end; ++i) {
445
+ EXPECT_EQ(static_cast<float>(answer[i]), static_cast<float>(result[i]));
446
+ }
447
+ }
448
+
449
+ TEST(CsrBlockSparseMatrix, SumVectors_Generic) {
450
+ TestSumVectors<float>();
451
+ TestSumVectors<float>(1);
452
+ TestSumVectors<float>(1, 4);
453
+ }
454
+
455
+ TEST(CsrBlockSparseMatrix, SumVectors_Bfloat16) {
456
+ TestSumVectors<csrblocksparse::bfloat16>();
457
+ TestSumVectors<csrblocksparse::bfloat16>(1);
458
+ TestSumVectors<csrblocksparse::bfloat16>(1, 4);
459
+ }
460
+
461
+ // For SIMD-optimized SumVectors, the memory of the vector should be at least
462
+ // |kSIMDWidth * sizeof(float)| long, and the start position has to be an
463
+ // aligned memory location. So we set |size| to 100 to be safe and |start|
464
+ // to 0 (|start| == 1 is not aligned).
465
+ TEST(CsrBlockSparseMatrix, SumVectors_Fixed16) {
466
+ TestSumVectors<csrblocksparse::fixed16<8>>(0, -1, 100);
467
+ TestSumVectors<csrblocksparse::fixed16<8>>(0, 4, 100);
468
+ }
469
+
470
+ TEST(CsrBlockSparseMatrix, SumVectors_Fixed32) {
471
+ TestSumVectors<csrblocksparse::fixed32<11>>(0, -1, 100);
472
+ TestSumVectors<csrblocksparse::fixed32<11>>(0, 4, 100);
473
+ }
474
+
475
+ TEST(CsrBlockSparseMatrix, SpMM_Block4x4_Bfloat16) {
476
+ TestSpMM<csrblocksparse::bfloat16, float, float>(/*block_width=*/4,
477
+ /*block_height=*/4,
478
+ /*fatness=*/7);
479
+ }
480
+
481
+ // This actually uses multiple threads, and uses the output as the input for
482
+ // multiple steps to test that synchronization and memory visibility is
483
+ // working correctly. Requires square matrices.
484
+ TEST(CsrBlockSparseMatrix, SpMV_4x4MultiThreading_Bfloat16) {
485
+ TestSpMM_MultiThread<csrblocksparse::bfloat16, float, float>(
486
+ /*block_width=*/4,
487
+ /*block_height=*/4,
488
+ /*fatness=*/1);
489
+ }
490
+
491
+ TEST(CsrBlockSparseMatrix, SpMM_4x4MultiThreading_Bfloat16) {
492
+ TestSpMM_MultiThread<csrblocksparse::bfloat16, float, float>(
493
+ /*block_width=*/4,
494
+ /*block_height=*/4,
495
+ /*fatness=*/7);
496
+ }
497
+
498
+ TEST(CsrBlockSparseMatrix, SpMV_Block1x1_Bfloat16) {
499
+ TestSpMM<csrblocksparse::bfloat16, float, float>(/*block_width=*/1,
500
+ /*block_height=*/1,
501
+ /*fatness=*/1);
502
+ }
503
+
504
+ TEST(CsrBlockSparseMatrix, SpMM_Block1x1_Bfloat16) {
505
+ TestSpMM<csrblocksparse::bfloat16, float, float>(/*block_width=*/1,
506
+ /*block_height=*/1,
507
+ /*fatness=*/7);
508
+ }
509
+
510
+ // This actually uses multiple threads, and uses the output as the input for
511
+ // multiple steps to test that synchronization and memory visibility is
512
+ // working correctly. Requires square matrices.
513
+ TEST(CsrBlockSparseMatrix, SpMV_1x1MultiThreading_Bfloat16) {
514
+ TestSpMM_MultiThread<csrblocksparse::bfloat16, float, float>(
515
+ /*block_width=*/1,
516
+ /*block_height=*/1,
517
+ /*fatness=*/1);
518
+ }
519
+
520
+ TEST(CsrBlockSparseMatrix, SpMM_1x1MultiThreading_Bfloat16) {
521
+ TestSpMM_MultiThread<csrblocksparse::bfloat16, float, float>(
522
+ /*block_width=*/1,
523
+ /*block_height=*/1,
524
+ /*fatness=*/7);
525
+ }
526
+
527
+ TEST(CsrBlockSparseMatrix, SpMV_Block4x4_float) {
528
+ TestSpMM<float, float, float>(/*block_width=*/4,
529
+ /*block_height=*/4,
530
+ /*fatness=*/1,
531
+ /*test_matmul=*/true);
532
+ }
533
+
534
+ TEST(CsrBlockSparseMatrix, SpMM_Block4x4_float) {
535
+ TestSpMM<float, float, float>(/*block_width=*/4,
536
+ /*block_height=*/4,
537
+ /*fatness=*/7);
538
+ }
539
+
540
+ // This actually uses multiple threads, and uses the output as the input for
541
+ // multiple steps to test that synchronization and memory visibility is
542
+ // working correctly. Requires square matrices.
543
+ TEST(CsrBlockSparseMatrix, SpMV_4x4MultiThreading_float) {
544
+ TestSpMM_MultiThread<float, float, float>(/*block_width=*/4,
545
+ /*block_height=*/4,
546
+ /*fatness=*/1);
547
+ }
548
+
549
+ TEST(CsrBlockSparseMatrix, SpMM_4x4MultiThreading_float) {
550
+ TestSpMM_MultiThread<float, float, float>(/*block_width=*/4,
551
+ /*block_height=*/4,
552
+ /*fatness=*/7);
553
+ }
554
+
555
+ TEST(CsrBlockSparseMatrix, SpMV_Block1x1_float) {
556
+ TestSpMM<float, float, float>(/*block_width=*/1,
557
+ /*block_height=*/1,
558
+ /*fatness=*/1);
559
+ }
560
+
561
+ TEST(CsrBlockSparseMatrix, SpMM_Block1x1_float) {
562
+ TestSpMM<float, float, float>(/*block_width=*/1,
563
+ /*block_height=*/1,
564
+ /*fatness=*/7);
565
+ }
566
+
567
+ // This actually uses multiple threads, and uses the output as the input for
568
+ // multiple steps to test that synchronization and memory visibility is
569
+ // working correctly. Requires square matrices.
570
+ TEST(CsrBlockSparseMatrix, SpMV_1x1MultiThreading_float) {
571
+ TestSpMM_MultiThread<float, float, float>(/*block_width=*/1,
572
+ /*block_height=*/1,
573
+ /*fatness=*/1);
574
+ }
575
+
576
+ TEST(CsrBlockSparseMatrix, SpMM_1x1MultiThreading_float) {
577
+ TestSpMM_MultiThread<float, float, float>(/*block_width=*/1,
578
+ /*block_height=*/1,
579
+ /*fatness=*/7);
580
+ }
581
+
582
+ TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_32) {
583
+ TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
584
+ typename csrblocksparse::TypeOfProduct<
585
+ csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
586
+ /*block_width=*/4,
587
+ /*block_height=*/4,
588
+ /*fatness=*/1,
589
+ /*test_matmul=*/true);
590
+ }
591
+
592
+ TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_32) {
593
+ TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
594
+ typename csrblocksparse::TypeOfProduct<
595
+ csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
596
+ /*block_width=*/4,
597
+ /*block_height=*/4,
598
+ /*fatness=*/7);
599
+ }
600
+
601
+ TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_32) {
602
+ TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
603
+ typename csrblocksparse::TypeOfProduct<
604
+ csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
605
+ /*block_width=*/1,
606
+ /*block_height=*/1,
607
+ /*fatness=*/1);
608
+ }
609
+
610
+ TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_32) {
611
+ TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
612
+ typename csrblocksparse::TypeOfProduct<
613
+ csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
614
+ /*block_width=*/1,
615
+ /*block_height=*/1,
616
+ /*fatness=*/7);
617
+ }
618
+
619
+ TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_16) {
620
+ TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
621
+ csrblocksparse::fixed16<8>>(
622
+ /*block_width=*/4,
623
+ /*block_height=*/4,
624
+ /*fatness=*/1,
625
+ /*test_matmul=*/true);
626
+ }
627
+
628
+ TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_16) {
629
+ TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
630
+ csrblocksparse::fixed16<8>>(
631
+ /*block_width=*/4,
632
+ /*block_height=*/4,
633
+ /*fatness=*/7);
634
+ }
635
+
636
+ TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_16) {
637
+ TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
638
+ csrblocksparse::fixed16<8>>(
639
+ /*block_width=*/1,
640
+ /*block_height=*/1,
641
+ /*fatness=*/1);
642
+ }
643
+
644
+ TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_16) {
645
+ TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
646
+ csrblocksparse::fixed16<8>>(
647
+ /*block_width=*/1,
648
+ /*block_height=*/1,
649
+ /*fatness=*/7);
650
+ }
651
+
652
+ TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_32_unmatched) {
653
+ TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
654
+ csrblocksparse::fixed32<13>>(
655
+ /*block_width=*/4,
656
+ /*block_height=*/4,
657
+ /*fatness=*/1,
658
+ /*test_matmul=*/true);
659
+ }
660
+
661
+ TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_32_unmatched) {
662
+ TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
663
+ csrblocksparse::fixed32<13>>(
664
+ /*block_width=*/4,
665
+ /*block_height=*/4,
666
+ /*fatness=*/7);
667
+ }
668
+
669
+ TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_32_unmatched) {
670
+ TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
671
+ csrblocksparse::fixed32<13>>(
672
+ /*block_width=*/1,
673
+ /*block_height=*/1,
674
+ /*fatness=*/1);
675
+ }
676
+
677
+ TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_32_unmatched) {
678
+ TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
679
+ csrblocksparse::fixed32<13>>(
680
+ /*block_width=*/1,
681
+ /*block_height=*/1,
682
+ /*fatness=*/7);
683
+ }
684
+
685
+ TEST(CsrBlockSparseMatrix, RhsIndicesDeltasRoundTrip) {
686
+ MaskedSparseMatrix<float> matrix(/*rows=*/256, /*cols=*/256,
687
+ /*sparsity=*/0.9, /*block_height=*/4,
688
+ /*block_width=*/4);
689
+ CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);
690
+ CacheAlignedVector<int16_t> copy_indices = sparse_matrix.rhs_indices();
691
+ sparse_matrix.ComputeColDeltas();
692
+ sparse_matrix.ComputeRHSIndices();
693
+ // They get padded when created, so the newer one could be bigger.
694
+ EXPECT_LE(copy_indices.size(), sparse_matrix.rhs_indices().size());
695
+ for (int i = 0; i < copy_indices.size(); ++i) {
696
+ EXPECT_EQ(copy_indices[i], sparse_matrix.rhs_indices()[i]) << "i=" << i;
697
+ }
698
+ }
699
+
700
+ // Tests that a Layer that is split into 2 by columns (inputs) computes the same
701
+ // result as the original layer.
702
+ TEST(CsrBlockSparseMatrix, SplitByCol) {
703
+ int kRows = 1024;
704
+ int kCols = 1024;
705
+ MaskedSparseMatrix<float> matrix(kRows, kCols, 0.95, /*block_height=*/4,
706
+ /*block_width=*/4);
707
+ FatCacheAlignedVector<float> rhs(kCols, /*cols=*/1);
708
+ CacheAlignedVector<float> bias(kRows);
709
+ FatCacheAlignedVector<float> out1(kRows, /*cols=*/1);
710
+ FatCacheAlignedVector<float> out2(kRows, /*cols=*/1);
711
+
712
+ bias.FillRandom();
713
+ rhs.FillRandom();
714
+ out1.FillZero();
715
+ out2.FillZero();
716
+ FatCacheAlignedVector<float> out_reference = out1;
717
+
718
+ CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);
719
+
720
+ SparseLinearLayer<float, float> sparse_linear_layer(std::move(sparse_matrix),
721
+ std::move(bias));
722
+ sparse_linear_layer.PrepareForThreads(1);
723
+ sparse_linear_layer.SpMM_bias(rhs, &out_reference, /*relu=*/false,
724
+ /*tid=*/0);
725
+ // Split the layer into 2 parts.
726
+ SparseLinearLayer<float, float> part1, part2;
727
+ sparse_linear_layer.SplitInputs(&part1, &part2);
728
+ part1.PrepareForThreads(1);
729
+ part2.PrepareForThreads(1);
730
+ EXPECT_EQ(kRows, part1.rows());
731
+ EXPECT_EQ(kCols / 2, part1.cols());
732
+ EXPECT_EQ(kRows, part2.rows());
733
+ EXPECT_EQ(kCols / 2, part2.cols());
734
+ MutableVectorView<float> rhs1(&rhs, 0, kCols / 2);
735
+ MutableVectorView<float> rhs2(&rhs, kCols / 2, kCols / 2);
736
+ for (int i = 0; i < kCols / 2; ++i) {
737
+ EXPECT_FLOAT_EQ(rhs[i], rhs1.data()[i]);
738
+ EXPECT_FLOAT_EQ(rhs[i + kCols / 2], rhs2.data()[i]);
739
+ }
740
+ part1.SpMM_bias(rhs1, &out1, /*relu=*/false, /*tid=*/0);
741
+ part2.SpMM_bias(rhs2, &out2, /*relu=*/false, /*tid=*/0);
742
+ // Check that out1 + out2 = out_reference.
743
+ for (int i = 0; i < kRows; ++i) {
744
+ EXPECT_NEAR(out_reference[i], out1[i] + out2[i], 2e-5)
745
+ << " i=" << i << " out1=" << out1[i] << " out2=" << out2[i];
746
+ }
747
+ }
748
+ // Tests that a Layer that is split into 2 by rows (outputs) computes the same
749
+ // result as the original layer.
750
+ TEST(CsrBlockSparseMatrix, SplitByRow) {
751
+ int kRows = 1024;
752
+ int kCols = 1024;
753
+ MaskedSparseMatrix<float> matrix(kRows, kCols, 0.95, /*block_height=*/4,
754
+ /*block_width=*/4);
755
+ FatCacheAlignedVector<float> rhs(kCols, /*cols=*/1);
756
+ CacheAlignedVector<float> bias(kRows);
757
+ FatCacheAlignedVector<float> out1(kRows, /*cols=*/1);
758
+ FatCacheAlignedVector<float> out2(kRows, /*cols=*/1);
759
+
760
+ bias.FillRandom();
761
+ rhs.FillRandom();
762
+ out1.FillZero();
763
+ out2.FillZero();
764
+ FatCacheAlignedVector<float> out_reference = out1;
765
+
766
+ CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);
767
+
768
+ SparseLinearLayer<float, float> sparse_linear_layer(std::move(sparse_matrix),
769
+ std::move(bias));
770
+ sparse_linear_layer.PrepareForThreads(1);
771
+ sparse_linear_layer.SpMM_bias(rhs, &out_reference, /*relu=*/false,
772
+ /*tid=*/0);
773
+ // Split the layer into 2 parts.
774
+ SparseLinearLayer<float, float> part1, part2;
775
+ sparse_linear_layer.SplitOutputs(&part1, &part2);
776
+ part1.PrepareForThreads(1);
777
+ part2.PrepareForThreads(1);
778
+ EXPECT_EQ(kRows / 2, part1.rows());
779
+ EXPECT_EQ(kCols, part1.cols());
780
+ EXPECT_EQ(kRows / 2, part2.rows());
781
+ EXPECT_EQ(kCols, part2.cols());
782
+ MutableVectorView<float> out2a(&out2, 0, kRows / 2);
783
+ MutableVectorView<float> out2b(&out2, kRows / 2, kRows / 2);
784
+ part1.SpMM_bias(rhs, &out2a, /*relu=*/false, /*tid=*/0);
785
+ part2.SpMM_bias(rhs, &out2b, /*relu=*/false, /*tid=*/0);
786
+ // Check that out2 = out_reference.
787
+ for (int i = 0; i < kRows; ++i) {
788
+ EXPECT_NEAR(out_reference[i], out2[i], 2e-5)
789
+ << " i=" << i << " out1=" << out_reference[i] << " out2=" << out2[i];
790
+ }
791
+ }
792
+
793
+ TEST(CsrBlockSparseMatrix, MutableVectorView) {
794
+ const int kRows = 1024;
795
+ const int kCols = 1024;
796
+ const int kFatness = 2;
797
+
798
+ std::vector<float> values(kRows * kCols, 1.f);
799
+ std::vector<int> mask(kRows * kCols);
800
+ for (int i = 0; i < mask.size(); ++i) mask[i] = i % 2;
801
+
802
+ auto masked_matrix =
803
+ MaskedSparseMatrix<float>(kRows, kCols, mask.data(), values.data());
804
+ auto sparse_matrix = CsrBlockSparseMatrix<bfloat16, float>(masked_matrix);
805
+ FatCacheAlignedVector<float> x(kCols, kFatness);
806
+ x.FillOnes();
807
+
808
+ CacheAlignedVector<float> bias(kRows);
809
+ bias.FillZero();
810
+
811
+ // First check that we can use spans as output. Split a multiplication
812
+ // into upper and lower halves times the full vector:
813
+ // --------------- x t
814
+ // | | x t
815
+ // | | x t
816
+ // --------------- =
817
+ // | | x b
818
+ // | | x b
819
+ // --------------- x b
820
+
821
+ FatCacheAlignedVector<float> out(kRows, kFatness);
822
+ FatCacheAlignedVector<float> out_view(kRows, kFatness);
823
+
824
+ MutableVectorView<float> out_view_top(&out_view, 0, kRows / 2);
825
+ MutableVectorView<float> out_view_bottom(&out_view, kRows / 2, kRows / 2);
826
+
827
+ sparse_matrix.SpMM_bias(x, bias, &out);
828
+
829
+ auto masked_matrix_top =
830
+ MaskedSparseMatrix<float>(kRows / 2, kCols, mask.data(), values.data());
831
+ auto masked_matrix_bottom = MaskedSparseMatrix<float>(
832
+ kRows / 2, kCols, mask.data() + kRows * kCols / 2,
833
+ values.data() + kRows * kCols / 2);
834
+ auto sparse_matrix_top =
835
+ CsrBlockSparseMatrix<bfloat16, float>(masked_matrix_top);
836
+ auto sparse_matrix_bottom =
837
+ CsrBlockSparseMatrix<bfloat16, float>(masked_matrix_bottom);
838
+
839
+ sparse_matrix_top.SpMM_bias(x, bias, &out_view_top);
840
+ sparse_matrix_bottom.SpMM_bias(x, bias, &out_view_bottom);
841
+
842
+ CheckResult(out, out_view, kCols);
843
+
844
+ // Check that we can use a span as an input vector. Multiply upper left
845
+ // portion of the matrix by the top half of the vector.
846
+ // ---------------
847
+ // |oooooo | x q
848
+ // |oooooo | x q
849
+ // | | =
850
+ // | |
851
+ // ---------------
852
+
853
+ auto masked_matrix_quarter = MaskedSparseMatrix<float>(
854
+ kRows / 2, kCols / 2, mask.data(), values.data());
855
+ auto sparse_matrix_quarter =
856
+ CsrBlockSparseMatrix<bfloat16, float>(masked_matrix_quarter);
857
+
858
+ MutableVectorView<float> x_top(&x, 0, kCols / 2);
859
+ FatCacheAlignedVector<float> out_correct(kRows / 2, /*cols=*/2);
860
+
861
+ for (int i = 0; i < kFatness * (kRows / 2); ++i) out_correct[i] = 256.f;
862
+
863
+ MutableVectorView<float> bias_top(&bias, 0, kRows / 2);
864
+ FatCacheAlignedVector<float> out_quarter(kRows / 2, kFatness);
865
+
866
+ sparse_matrix_quarter.SpMM_bias(x_top, bias_top, &out_quarter);
867
+
868
+ CheckResult(out_correct, out_quarter, kCols / 2);
869
+ }
870
+
871
+ namespace {
872
+
873
+ bool skip_test(const absl::Status& status, absl::string_view msg) {
874
+ if (!status.ok()) {
875
+ LOG(INFO) << "Couldn't load " << msg << ", skipping test " << status;
876
+ return true;
877
+ }
878
+
879
+ return false;
880
+ }
881
+
882
+ } // namespace
883
+
884
+ TEST(CsrBlockSparseMatrix, ModelMatrices_Bfloat16) {
885
+ std::vector<std::string> names = {
886
+ "768_512_95_4x4_wavernn_gru_", "768_512_95_4x4_coarseproj_",
887
+ "768_512_95_4x4_coarselogit_", "768_512_95_4x4_fineproj_",
888
+ "768_512_95_4x4_finelogit_", "lyra_conv1d_"};
889
+ const std::string kPath =
890
+ #if defined __arm__ || defined __aarch64__
891
+ "/data/local/tmp/";
892
+ #else
893
+ (ghc::filesystem::current_path() / kTestdataPath).string();
894
+ #endif
895
+ for (auto& layer_name : names) {
896
+ SparseLinearLayer<bfloat16, float> sparse_linear_layer;
897
+ auto status = LoadSparseLayer<bfloat16, float>(layer_name, /*zipped=*/true,
898
+ &sparse_linear_layer, kPath);
899
+ // If the files don't exist on the device we're running on, just skip this
900
+ // test and log that it was skipped.
901
+ if (skip_test(status, layer_name)) return;
902
+
903
+ int rows = sparse_linear_layer.rows();
904
+ int cols = sparse_linear_layer.cols();
905
+
906
+ MaskedLinearLayer<float> masked_linear_layer;
907
+ status = LoadMaskedLayer<float>(layer_name, /*zipped=*/true,
908
+ &masked_linear_layer, kPath);
909
+ if (skip_test(status, layer_name)) return;
910
+ masked_linear_layer.CastWeights<csrblocksparse::bfloat16>();
911
+
912
+ CacheAlignedVector<float> rhs(cols);
913
+ CacheAlignedVector<float> out_ref(rows);
914
+ CacheAlignedVector<float> out_spmv(rows);
915
+
916
+ rhs.FillRandom();
917
+ out_ref.FillZero();
918
+ out_spmv.FillZero();
919
+
920
+ std::array<bool, 2> use_relus = {false, true};
921
+ for (bool use_relu : use_relus) {
922
+ masked_linear_layer.SpMM_bias(rhs, &out_ref, use_relu);
923
+ sparse_linear_layer.SpMM_bias(rhs, &out_spmv, use_relu);
924
+
925
+ CheckResult(out_ref, out_spmv, cols);
926
+ }
927
+ }
928
+ }
929
+
930
+ TEST(CsrBlockSparseMatrix, ModelMatrices_float) {
931
+ std::vector<std::string> names = {
932
+ "768_512_95_4x4_wavernn_gru_", "768_512_95_4x4_coarseproj_",
933
+ "768_512_95_4x4_coarselogit_", "768_512_95_4x4_fineproj_",
934
+ "768_512_95_4x4_finelogit_", "lyra_conv1d_"};
935
+ const std::string kPath =
936
+ #if defined __arm__ || defined __aarch64__
937
+ "/data/local/tmp/";
938
+ #else
939
+ (ghc::filesystem::current_path() / kTestdataPath).string();
940
+ #endif
941
+ for (auto& layer_name : names) {
942
+ SparseLinearLayer<float, float> sparse_linear_layer;
943
+ auto status = LoadSparseLayer<float, float>(layer_name, /*zipped=*/true,
944
+ &sparse_linear_layer, kPath);
945
+ // If the files don't exist on the device we're running on, just skip this
946
+ // test and log that it was skipped.
947
+ if (skip_test(status, layer_name)) return;
948
+
949
+ int rows = sparse_linear_layer.rows();
950
+ int cols = sparse_linear_layer.cols();
951
+
952
+ MaskedLinearLayer<float> masked_linear_layer;
953
+ status = LoadMaskedLayer<float>(layer_name, /*zipped=*/true,
954
+ &masked_linear_layer, kPath);
955
+ if (skip_test(status, layer_name)) return;
956
+
957
+ CacheAlignedVector<float> rhs(cols);
958
+ CacheAlignedVector<float> out_ref(rows);
959
+ CacheAlignedVector<float> out_spmv(rows);
960
+
961
+ rhs.FillRandom();
962
+ out_ref.FillZero();
963
+ out_spmv.FillZero();
964
+
965
+ std::array<bool, 2> use_relus = {false, true};
966
+ for (bool use_relu : use_relus) {
967
+ masked_linear_layer.SpMM_bias(rhs, &out_ref, use_relu);
968
+ sparse_linear_layer.SpMM_bias(rhs, &out_spmv, use_relu);
969
+
970
+ CheckResult(out_ref, out_spmv, cols);
971
+ }
972
+ }
973
+ }
974
+
975
+ #undef SKIP_TEST
976
+
977
+ } // namespace csrblocksparse
sparse_matmul/layers/errno_mapping.cc ADDED
@@ -0,0 +1,195 @@
1
+ // Copyright 2021 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #include "sparse_matmul/layers/errno_mapping.h"
16
+
17
+ #include <cerrno>
+ #include <cstring>
+ #include <string>
18
+
19
+ #include "absl/strings/str_cat.h"
20
+
21
+ namespace csrblocksparse {
22
+
23
+ namespace {
24
+
25
+ absl::StatusCode ErrnoToCode(int error_number) {
26
+ switch (error_number) {
27
+ case 0:
28
+ return absl::StatusCode::kOk;
29
+ case EINVAL: // Invalid argument
30
+ case ENAMETOOLONG: // Filename too long
31
+ case E2BIG: // Argument list too long
32
+ case EDESTADDRREQ: // Destination address required
33
+ case EDOM: // Mathematics argument out of domain of function
34
+ case EFAULT: // Bad address
35
+ case EILSEQ: // Illegal byte sequence
36
+ case ENOPROTOOPT: // Protocol not available
37
+ case ENOSTR: // Not a STREAM
38
+ case ENOTSOCK: // Not a socket
39
+ case ENOTTY: // Inappropriate I/O control operation
40
+ case EPROTOTYPE: // Protocol wrong type for socket
41
+ case ESPIPE: // Invalid seek
42
+ return absl::StatusCode::kInvalidArgument;
43
+ case ETIMEDOUT: // Connection timed out
44
+ case ETIME: // Timer expired
45
+ return absl::StatusCode::kDeadlineExceeded;
46
+ case ENODEV: // No such device
47
+ case ENOENT: // No such file or directory
48
+ #ifdef ENOMEDIUM
49
+ case ENOMEDIUM: // No medium found
50
+ #endif
51
+ case ENXIO: // No such device or address
52
+ case ESRCH: // No such process
53
+ return absl::StatusCode::kNotFound;
54
+ case EEXIST: // File exists
55
+ case EADDRNOTAVAIL: // Address not available
56
+ case EALREADY: // Connection already in progress
57
+ #ifdef ENOTUNIQ
58
+ case ENOTUNIQ: // Name not unique on network
59
+ #endif
60
+ return absl::StatusCode::kAlreadyExists;
61
+ case EPERM: // Operation not permitted
62
+ case EACCES: // Permission denied
63
+ #ifdef ENOKEY
64
+ case ENOKEY: // Required key not available
65
+ #endif
66
+ case EROFS: // Read only file system
67
+ return absl::StatusCode::kPermissionDenied;
68
+ case ENOTEMPTY: // Directory not empty
69
+ case EISDIR: // Is a directory
70
+ case ENOTDIR: // Not a directory
71
+ case EADDRINUSE: // Address already in use
72
+ case EBADF: // Invalid file descriptor
73
+ #ifdef EBADFD
74
+ case EBADFD: // File descriptor in bad state
75
+ #endif
76
+ case EBUSY: // Device or resource busy
77
+ case ECHILD: // No child processes
78
+ case EISCONN: // Socket is connected
79
+ #ifdef EISNAM
80
+ case EISNAM: // Is a named type file
81
+ #endif
82
+ #ifdef ENOTBLK
83
+ case ENOTBLK: // Block device required
84
+ #endif
85
+ case ENOTCONN: // The socket is not connected
86
+ case EPIPE: // Broken pipe
87
+ #ifdef ESHUTDOWN
88
+ case ESHUTDOWN: // Cannot send after transport endpoint shutdown
89
+ #endif
90
+ case ETXTBSY: // Text file busy
91
+ #ifdef EUNATCH
92
+ case EUNATCH: // Protocol driver not attached
93
+ #endif
94
+ return absl::StatusCode::kFailedPrecondition;
95
+ case ENOSPC: // No space left on device
96
+ #ifdef EDQUOT
97
+ case EDQUOT: // Disk quota exceeded
98
+ #endif
99
+ case EMFILE: // Too many open files
100
+ case EMLINK: // Too many links
101
+ case ENFILE: // Too many open files in system
102
+ case ENOBUFS: // No buffer space available
103
+ case ENODATA: // No message is available on the STREAM read queue
104
+ case ENOMEM: // Not enough space
105
+ case ENOSR: // No STREAM resources
106
+ #ifdef EUSERS
107
+ case EUSERS: // Too many users
108
+ #endif
109
+ return absl::StatusCode::kResourceExhausted;
110
+ #ifdef ECHRNG
111
+ case ECHRNG: // Channel number out of range
112
+ #endif
113
+ case EFBIG: // File too large
114
+ case EOVERFLOW: // Value too large to be stored in data type
115
+ case ERANGE: // Result too large
116
+ return absl::StatusCode::kOutOfRange;
117
+ #ifdef ENOPKG
118
+ case ENOPKG: // Package not installed
119
+ #endif
120
+ case ENOSYS: // Function not implemented
121
+ case ENOTSUP: // Operation not supported
122
+ case EAFNOSUPPORT: // Address family not supported
123
+ #ifdef EPFNOSUPPORT
124
+ case EPFNOSUPPORT: // Protocol family not supported
125
+ #endif
126
+ case EPROTONOSUPPORT: // Protocol not supported
127
+ #ifdef ESOCKTNOSUPPORT
128
+ case ESOCKTNOSUPPORT: // Socket type not supported
129
+ #endif
130
+ case EXDEV: // Improper link
131
+ return absl::StatusCode::kUnimplemented;
132
+ case EAGAIN: // Resource temporarily unavailable
133
+ #ifdef ECOMM
134
+ case ECOMM: // Communication error on send
135
+ #endif
136
+ case ECONNREFUSED: // Connection refused
137
+ case ECONNABORTED: // Connection aborted
138
+ case ECONNRESET: // Connection reset
139
+ case EINTR: // Interrupted function call
140
+ #ifdef EHOSTDOWN
141
+ case EHOSTDOWN: // Host is down
142
+ #endif
143
+ case EHOSTUNREACH: // Host is unreachable
144
+ case ENETDOWN: // Network is down
145
+ case ENETRESET: // Connection aborted by network
146
+ case ENETUNREACH: // Network unreachable
147
+ case ENOLCK: // No locks available
148
+ case ENOLINK: // Link has been severed
149
+ #ifdef ENONET
150
+ case ENONET: // Machine is not on the network
151
+ #endif
152
+ return absl::StatusCode::kUnavailable;
153
+ case EDEADLK: // Resource deadlock avoided
154
+ #ifdef ESTALE
155
+ case ESTALE: // Stale file handle
156
+ #endif
157
+ return absl::StatusCode::kAborted;
158
+ case ECANCELED: // Operation cancelled
159
+ return absl::StatusCode::kCancelled;
160
+ default:
161
+ return absl::StatusCode::kUnknown;
162
+ }
163
+ }
164
+
165
+ // POSIX `strerror_r()` returns `int`.
166
+ ABSL_ATTRIBUTE_UNUSED std::string StrErrorResult(int result, const char* buffer,
167
+ int error_code) {
168
+ if (ABSL_PREDICT_FALSE(result != 0)) {
169
+ return absl::StrCat("Unknown error ", error_code);
170
+ }
171
+ return buffer;
172
+ }
173
+
174
+ // GNU `strerror_r()` returns `char*`.
175
+ ABSL_ATTRIBUTE_UNUSED std::string StrErrorResult(char* result,
176
+ const char* buffer,
177
+ int error_code) {
178
+ return result;
179
+ }
180
+
181
+ std::string StrError(int error_code) {
182
+ char message[256];
183
+ return StrErrorResult(strerror_r(error_code, message, sizeof(message)),
184
+ message, error_code);
185
+ }
186
+
187
+ } // namespace
188
+
189
+ absl::Status ErrnoToCanonicalStatus(int error_number,
190
+ absl::string_view message) {
191
+ return absl::Status(ErrnoToCode(error_number),
192
+ absl::StrCat(message, ": ", StrError(error_number)));
193
+ }
194
+
195
+ } // namespace csrblocksparse
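A minimal usage sketch for ErrnoToCanonicalStatus follows; the OpenWeights wrapper and the file it opens are hypothetical and only illustrate how a failing C call plus errno is turned into an absl::Status.

// Sketch: wrap a failing C library call into an absl::Status. OpenWeights and
// the path it receives are hypothetical, for illustration only.
#include <cerrno>
#include <cstdio>

#include "absl/status/status.h"
#include "sparse_matmul/layers/errno_mapping.h"

absl::Status OpenWeights(const char* path) {
  std::FILE* file = std::fopen(path, "rb");
  if (file == nullptr) {
    // errno was set by fopen(); attach a canonical code and readable message.
    return csrblocksparse::ErrnoToCanonicalStatus(errno, "fopen failed");
  }
  std::fclose(file);
  return absl::OkStatus();
}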
sparse_matmul/layers/errno_mapping.h ADDED
@@ -0,0 +1,29 @@
1
+ // Copyright 2021 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #ifndef THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_ERRNO_MAPPING_H_
16
+ #define THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_ERRNO_MAPPING_H_
17
+
18
+ #include "absl/status/status.h"
19
+ #include "absl/strings/string_view.h"
20
+
21
+ namespace csrblocksparse {
22
+
23
+ // Converts |error_number| value to absl::Status.
24
+ absl::Status ErrnoToCanonicalStatus(int error_number,
25
+ absl::string_view message);
26
+
27
+ } // namespace csrblocksparse
28
+
29
+ #endif // THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_ERRNO_MAPPING_H_
sparse_matmul/layers/masked_sparse_matrix.h ADDED
@@ -0,0 +1,206 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_
19
+
20
+ #include <algorithm>
21
+ #include <cstdio>
22
+ #include <numeric>
+ #include <random>
23
+ #include <vector>
24
+
25
+ #include "absl/strings/str_format.h"
26
+ #include "sparse_matmul/vector/cache_aligned_vector.h"
27
+
28
+ namespace csrblocksparse {
29
+
30
+ // MaskedSparseMatrix serves two purposes:
31
+ // 1) It is useful as a reference implementation of SpMV for correctness
32
+ // checking the much more complicated implementations in CsrBlockSparseMatrix
33
+ // 2) This is the format that sparse matrices are represented after pruning
34
+ // in TF. This class provides a bridge to getting these parameters into
35
+ // a compressed form suitable for computation and serialization.
36
+ //
37
+ // MaskedSparseMatrix<float> matrix(rows, cols, mask_from_tf, values_from_tf);
38
+ // CsrBlockSparseMatrix<float, bfloat16, int16_t> csr_matrix(matrix);
39
+ // csr_matrix.Multiply(rhs, bias, &out);
40
+ template <typename T>
41
+ class MaskedSparseMatrix {
42
+ public:
43
+ MaskedSparseMatrix() {}
44
+
45
+ // Construct a MaskedSparseMatrix of the given size, sparsity and block size.
46
+ // This is mainly useful for testing.
47
+ MaskedSparseMatrix(int rows, int cols, float sparsity, int block_height = 1,
48
+ int block_width = 1, float constant = 1.f,
49
+ bool random = true)
50
+ : rows_(rows), cols_(cols), sparsity_(sparsity) {
51
+ CHECK_EQ(rows % block_height, 0);
52
+ CHECK_EQ(cols % block_width, 0);
53
+
54
+ init(sparsity, block_height, block_width, constant, random);
55
+ }
56
+
57
+ // Construct from an existing mask and values (most likely from a TF model).
58
+ template <typename MaskType>
59
+ MaskedSparseMatrix(int rows, int cols, const MaskType* mask, const T* values)
60
+ : rows_(rows), cols_(cols) {
61
+ mask_.resize(rows * cols);
62
+ values_.resize(rows * cols);
63
+ std::copy_n(mask, rows * cols, mask_.begin());
64
+ std::copy_n(values, rows * cols, values_.begin());
65
+ sparsity_ =
66
+ 1.f - std::accumulate(mask_.begin(), mask_.end(), 0.f) / mask_.size();
67
+ }
68
+
69
+ const std::vector<int>& mask() const { return mask_; }
70
+ const std::vector<T>& values() const { return values_; }
71
+ T* data() { return values_.data(); }
72
+ const T* data() const { return values_.data(); }
73
+
74
+ int rows() const { return rows_; }
75
+ int cols() const { return cols_; }
76
+ float sparsity() const { return sparsity_; }
77
+
78
+ void Print() const {
79
+ absl::PrintF("-------Values---------\n");
80
+ for (int r = 0; r < rows_; ++r) {
81
+ for (int c = 0; c < cols_; ++c) {
82
+ absl::PrintF("%+6.3f ", static_cast<float>(values_[r * cols_ + c]));
83
+ }
84
+ absl::PrintF("\n");
85
+ }
86
+ absl::PrintF("-------Mask---------\n");
87
+ for (int r = 0; r < rows_; ++r) {
88
+ for (int c = 0; c < cols_; ++c) {
89
+ printf("%2d ", mask_[r * cols_ + c]);
90
+ }
91
+ absl::PrintF("\n");
92
+ }
93
+ }
94
+
95
+ // This routine is useful for rounding the possibly higher precision values
96
+ // stored in this class to a lower precision, so that correctness checks
97
+ // between this class and CSRBlockSparseMatrix can have a tighter tolerance.
98
+ template <typename U>
99
+ void CastWeights() {
100
+ for (int i = 0; i < values_.size(); ++i) {
101
+ values_[i] = static_cast<T>(U(values_[i]));
102
+ }
103
+ }
104
+
105
+ // Only meant for correctness checking.
106
+ // RhsClassType is meant to be either CacheAlignedVector OR
107
+ // FatCacheAlignedVector.
108
+ // The weight matrix is ROW MAJOR and RhsClassType is COLUMN MAJOR.
109
+ // |bias| is broadcast if |rhs| has more than one column.
110
+ template <typename RhsClassType, typename BiasType, typename OutClassType,
111
+ typename RhsType = typename RhsClassType::value_type,
112
+ typename OutType = typename OutClassType::value_type>
113
+ void SpMM_bias(const RhsClassType& rhs,
114
+ const CacheAlignedVector<BiasType>& bias, OutClassType* out,
115
+ bool relu = false) {
116
+ for (int r = 0; r < rows_; ++r) {
117
+ for (int n = 0; n < rhs.cols(); ++n) {
118
+ float sum = 0.f;
119
+ const RhsType* rhs_ptr = rhs.data() + n * rhs.rows();
120
+ OutType* out_ptr = out->data() + n * out->rows();
121
+ const int* mask_ptr = mask_.data() + r * cols_;
122
+ const T* value_ptr = values_.data() + r * cols_;
123
+ for (int c = 0; c < cols_; ++c) {
124
+ sum += mask_ptr[c] * static_cast<float>(value_ptr[c]) *
125
+ static_cast<float>(rhs_ptr[c]);
126
+ }
127
+ out_ptr[r] = static_cast<OutType>(
128
+ relu ? std::max(sum + static_cast<float>(bias[r]), 0.f)
129
+ : sum + static_cast<float>(bias[r]));
130
+ }
131
+ }
132
+ }
133
+
134
+ private:
135
+ // Generate a random matrix with the specified sparsity.
136
+ // Useful for testing.
137
+ void init(float sparsity, int block_height, int block_width, float constant,
138
+ bool random = true) {
139
+ int reduced_rows = rows_ / block_height;
140
+ int reduced_cols = cols_ / block_width;
141
+ mask_.resize(rows_ * cols_, 0);
142
+
143
+ // Fill with non-zero value to make sure masking works.
144
+ values_.resize(rows_ * cols_, static_cast<T>(2.f));
145
+
146
+ std::mt19937 generator(0);
147
+ std::uniform_real_distribution<float> dist_sparsity;
148
+ std::uniform_real_distribution<float> dist_value(-1.f, 1.f);
149
+ int nnz = 0;
150
+ while (nnz == 0) {
151
+ for (int r = 0; r < reduced_rows; ++r) {
152
+ for (int c = 0; c < reduced_cols; ++c) {
153
+ if (dist_sparsity(generator) > sparsity) {
154
+ nnz++;
155
+ for (int i = 0; i < block_height; ++i) {
156
+ for (int j = 0; j < block_width; ++j) {
157
+ mask_[(r * block_height + i) * cols_ + block_width * c + j] = 1;
158
+ values_[(r * block_height + i) * cols_ + block_width * c + j] =
159
+ static_cast<T>(random ? dist_value(generator) : constant);
160
+ }
161
+ }
162
+ }
163
+ }
164
+ }
165
+ }
166
+ }
167
+
168
+ std::vector<int> mask_;
169
+ std::vector<T> values_;
170
+ int rows_;
171
+ int cols_;
172
+ float sparsity_;
173
+ };
174
+
175
+ template <typename T>
176
+ class MaskedLinearLayer {
177
+ public:
178
+ MaskedLinearLayer(MaskedSparseMatrix<T>&& weights,
179
+ CacheAlignedVector<T>&& bias)
180
+ : weights_(std::move(weights)), bias_(std::move(bias)) {}
181
+
182
+ MaskedLinearLayer() {}
183
+
184
+ template <typename U>
185
+ void CastWeights() {
186
+ weights_.template CastWeights<U>();
187
+ }
188
+
189
+ // Does Ax + b where A is a masked sparse ROW MAJOR matrix and
190
+ // x is a COLUMN MAJOR dense vector or matrix. Bias is a vector that is
191
+ // broadcast is rhs has more than one column.
192
+ template <typename FatVector>
193
+ void SpMM_bias(const FatVector& rhs, FatVector* out, bool relu = false) {
194
+ static_assert(std::is_same<typename FatVector::value_type, T>::value,
195
+ "FatVector value_type must match masked_linear_layer type");
196
+ weights_.SpMM_bias(rhs, bias_, out, relu);
197
+ }
198
+
199
+ private:
200
+ MaskedSparseMatrix<T> weights_;
201
+ CacheAlignedVector<T> bias_;
202
+ };
203
+
204
+ } // namespace csrblocksparse
205
+
206
+ #endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_
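The sketch below illustrates the bridge described in the MaskedSparseMatrix header comment, following the same call pattern as csrblocksparse_test.cc; the sizes, the alternating mask, and the bfloat16 weight type are assumptions made only for this example.

// Sketch only: convert a TF-style mask plus dense values into the compressed
// form and compare against the reference SpMM_bias. All sizes are arbitrary.
#include <vector>

#include "sparse_matmul/layers/csr_blocksparse_matrix.h"
#include "sparse_matmul/layers/masked_sparse_matrix.h"
#include "sparse_matmul/vector/cache_aligned_vector.h"

void MaskedToCompressedExample() {
  const int kRows = 256, kCols = 256;
  std::vector<float> values(kRows * kCols, 1.f);
  std::vector<int> mask(kRows * kCols);
  for (int i = 0; i < kRows * kCols; ++i) mask[i] = i % 2;  // ~50% sparse

  csrblocksparse::MaskedSparseMatrix<float> masked(kRows, kCols, mask.data(),
                                                   values.data());
  csrblocksparse::CsrBlockSparseMatrix<csrblocksparse::bfloat16, float>
      compressed(masked);

  csrblocksparse::FatCacheAlignedVector<float> rhs(kCols, /*cols=*/1);
  csrblocksparse::CacheAlignedVector<float> bias(kRows);
  csrblocksparse::FatCacheAlignedVector<float> out_ref(kRows, /*cols=*/1);
  csrblocksparse::FatCacheAlignedVector<float> out_csr(kRows, /*cols=*/1);
  rhs.FillOnes();
  bias.FillZero();

  masked.SpMM_bias(rhs, bias, &out_ref);      // reference implementation
  compressed.SpMM_bias(rhs, bias, &out_csr);  // compressed implementation
}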
sparse_matmul/layers/read_array_ifstream.h ADDED
@@ -0,0 +1,66 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ // Low-level array reading function using std::ifstream.
18
+
19
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_READ_ARRAY_IFSTREAM_H_
20
+ #define LYRA_CODEC_SPARSE_MATMUL_LAYERS_READ_ARRAY_IFSTREAM_H_
21
+
22
+ #include <cstdint>
23
+ #include <fstream>
24
+ #include <sstream>
25
+ #include <string>
26
+
27
+ #include "absl/status/status.h"
28
+ #include "absl/strings/substitute.h"
+ #include "glog/logging.h"
29
+ #include "include/ghc/filesystem.hpp"
30
+
31
+ namespace csrblocksparse {
32
+ namespace detail {
33
+
34
+ template <typename T>
35
+ absl::Status ReadArrayIfstream(const std::string& file_name,
36
+ const std::string& path, std::vector<T>* array,
37
+ int64_t* length) {
38
+ ghc::filesystem::path complete_path(path);
39
+ complete_path /= file_name;
40
+ std::ifstream in_stream(complete_path.u8string(), std::ios::binary);
41
+ if (!in_stream.is_open()) {
42
+ return absl::UnknownError(
43
+ absl::Substitute("Error opening $0", complete_path.string()));
44
+ }
45
+
46
+ std::stringstream buffer;
47
+ buffer << in_stream.rdbuf();
48
+ if (buffer.str().empty()) {
49
+ LOG(ERROR) << "File " << complete_path << " was empty.";
50
+ return absl::UnknownError(
51
+ absl::Substitute("File $0 was empty", complete_path.string()));
52
+ }
53
+ std::string contents = buffer.str();
54
+ *length = contents.length();
55
+ int64_t elem = (*length + sizeof(T) - 1) / sizeof(T);
56
+ array->resize(elem);
57
+ std::move(contents.begin(), contents.end(),
58
+ reinterpret_cast<char*>(array->data()));
59
+
60
+ return absl::OkStatus();
61
+ }
62
+
63
+ } // namespace detail
64
+ } // namespace csrblocksparse
65
+
66
+ #endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_READ_ARRAY_IFSTREAM_H_
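A hedged example of ReadArrayIfstream; the file name and directory below are placeholders, not files that ship with this source.

// Sketch: read a file of raw float32 bytes into a vector. "bias.raw" and
// "/tmp/model" are placeholder names.
#include <cstdint>
#include <vector>

#include "absl/status/status.h"
#include "sparse_matmul/layers/read_array_ifstream.h"

void ReadArrayExample() {
  std::vector<float> data;
  int64_t num_bytes = 0;
  const absl::Status status = csrblocksparse::detail::ReadArrayIfstream(
      "bias.raw", "/tmp/model", &data, &num_bytes);
  if (!status.ok()) return;
  // data.size() is num_bytes rounded up to a whole number of floats.
}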
sparse_matmul/layers/sparse_linear_layer.h ADDED
@@ -0,0 +1,365 @@
1
+ /*
2
+ * Copyright 2021 Google LLC
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_
18
+ #define LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_
19
+
20
+ #include <cstdint>
21
+
22
+ #include "absl/memory/memory.h"
23
+ #include "glog/logging.h"
24
+ #include "sparse_matmul/layers/csr_blocksparse_matrix.h"
25
+ #include "sparse_matmul/layers/masked_sparse_matrix.h"
26
+ #include "sparse_matmul/numerics/type_utils.h"
27
+ #include "sparse_matmul/os/coop_threads.h"
28
+ #include "sparse_matmul/vector/cache_aligned_vector.h"
29
+
30
+ namespace csrblocksparse {
31
+
32
+ template <typename WeightType, typename RhsType,
33
+ typename BiasType = typename TypeOfProduct<WeightType, RhsType>::type,
34
+ typename DeltaType = int16_t>
35
+ class SparseLinearLayer {
36
+ public:
37
+ SparseLinearLayer() {}
38
+
39
+ SparseLinearLayer(CsrBlockSparseMatrix<WeightType, RhsType>&& sparse_matrix,
40
+ CacheAlignedVector<BiasType>&& bias)
41
+ : sparse_matrix_(std::move(sparse_matrix)), full_bias_(std::move(bias)) {
42
+ CHECK_EQ(sparse_matrix_.rows(), full_bias_.size());
43
+ // Some kernels expect that the bias is divided by 4, so we store a second
44
+ // copy of a quarter of the bias.
45
+ // TODO(b/189958858): Remove the quartered bias if it can be done without
46
+ // loss of speed, and rename the |full_bias_| member back to |bias_|.
47
+ bias_ = full_bias_;
48
+ for (int i = 0; i < bias_.size(); ++i) {
49
+ bias_[i] = static_cast<BiasType>(.25f * static_cast<float>(bias_[i]));
50
+ }
51
+ }
52
+ SparseLinearLayer(
53
+ const SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>& src) {
54
+ *this = src;
55
+ }
56
+ SparseLinearLayer& operator=(
57
+ const SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>& src) {
58
+ sparse_matrix_ = src.sparse_matrix_;
59
+ bias_ = src.bias_;
60
+ full_bias_ = src.full_bias_;
61
+ mid_output_ = src.mid_output_;
62
+ thread_layers_ = src.thread_layers_;
63
+ num_threads_ = src.num_threads_;
64
+ if (src.split_pc_) {
65
+ split_pc_ = absl::make_unique<ProducerConsumer>(
66
+ src.split_pc_->num_producers(), src.split_pc_->num_consumers());
67
+ }
68
+ return *this;
69
+ }
70
+
71
+ // Does Ax + b where A is a block sparse compressed sparse row matrix and
72
+ // x is a COLUMN MAJOR dense vector or matrix. Bias is a vector that is
73
+ // broadcast if rhs has more than one column.
74
+ template <typename RhsClassType, typename OutType>
75
+ void SpMM_bias(const RhsClassType& rhs, OutType* out, bool relu = false,
76
+ int tid = 0, SpinBarrier* barrier = nullptr) const {
77
+ static_assert(
78
+ std::is_same<typename RhsClassType::value_type, RhsType>::value, "");
79
+ sparse_matrix_.SpMM_bias(rhs, bias_, out, relu, tid, barrier);
80
+ }
81
+ // Multiplies a sparse matrix by a possibly dense matrix, as SpMM_bias above,
82
+ // and then samples from the output (softmax distribution) layer.
83
+ template <typename RhsClassType, typename OutType>
84
+ int SpMM_bias_Sample(const RhsClassType& rhs, OutType* out, float temperature,
85
+ int tid, SpinBarrier* barrier, std::minstd_rand* gen,
86
+ CacheAlignedVector<float>* scratch) const {
87
+ static_assert(
88
+ std::is_same<typename RhsClassType::value_type, RhsType>::value, "");
89
+ return sparse_matrix_.SpMM_bias_Sample(rhs, bias_, out, temperature, tid,
90
+ barrier, gen, scratch);
91
+ }
92
+ template <typename RhsClassType, typename OutType>
93
+ void MatVec(const RhsClassType& rhs, bool relu, int tid, int replicas,
94
+ int output_stride, OutType* output,
95
+ SpinBarrier* barrier = nullptr) {
96
+ static_assert(
97
+ std::is_same<typename RhsClassType::value_type, RhsType>::value, "");
98
+ #ifdef __AVX2__
99
+ if (block_width() == 4 && (block_height() == 4 || block_height() == 8) &&
100
+ !IsCustomFloatType<WeightType>::value) {
101
+ if (!IsSplit()) {
102
+ sparse_matrix_.MatVec(rhs.cast_data(), full_bias_.cast_data(), relu,
103
+ tid, replicas, output_stride, output->data());
104
+ if (barrier != nullptr) barrier->barrier();
105
+ return;
106
+ }
107
+ // NOTE: Until the quartered bias is removed it is a bad idea to split
108
+ // for ARM in the same way, as we would have to quarter the output of
109
+ // the first part of the split before running the second part.
110
+ // Signal completion of the previous MatVec.
111
+ split_pc_->produce();
112
+ PartLinearLayer& thread_part = thread_layers_[tid];
113
+ auto offset_output =
114
+ sparse_matrix_.thread_bounds().OffsetOutput(output->data(), tid);
115
+ auto mid_output =
116
+ sparse_matrix_.thread_bounds().OffsetOutput(mid_output_.data(), tid);
117
+ auto offset_bias = sparse_matrix_.thread_bounds().OffsetOutput(
118
+ mid_output_.cast_data(), tid);
119
+ // We can continue to consume the data that this thread produced and
120
+ // compute just the |self_matrix| part.
121
+ // No |relu| or |replicas|, as this is only a partial matmul.
122
+ // |tid| is always zero because the matrix has been split by tid.
123
+ thread_part.self_matrix.MatVec(
124
+ rhs.cast_data(), thread_part.full_bias.cast_data(), /*relu=*/false,
125
+ /*tid=*/0, /*replicas=*/1, output_stride, mid_output);
126
+ // We have to wait for the other threads to finish working on the previous
127
+ // MatMul before consuming the rest of |rhs|.
128
+ split_pc_->consume();
129
+ thread_part.other_matrix.MatVec(rhs.cast_data(), offset_bias, relu,
130
+ /*tid=*/0, replicas, output_stride,
131
+ offset_output);
132
+ return;
133
+ }
134
+ #endif
135
+ DCHECK_EQ(replicas, 1) << "Must have single replica for SpMM API";
136
+ if (IsSplit()) {
137
+ // Generics aren't setup to use a split matrix. This will be inefficient.
138
+ split_pc_->produce();
139
+ split_pc_->consume();
140
+ }
141
+ if (block_height() == 8) {
142
+ // We are currently forced to use MatVec generics for this case.
143
+ LOG(WARNING) << "Need to implement MatVec for 8x4 for non-AVX2 targets!!";
144
+ sparse_matrix_.MatVec(rhs.cast_data(), full_bias_.cast_data(), relu, tid,
145
+ replicas, output_stride, output->data());
146
+ if (barrier != nullptr) barrier->barrier();
147
+ } else {
148
+ sparse_matrix_.SpMM_bias(rhs, bias_, output, relu, tid, barrier);
149
+ }
150
+ }
151
+
152
+ int rows() const { return sparse_matrix_.rows(); }
153
+ int cols() const { return sparse_matrix_.cols(); }
154
+ float sparsity() const { return sparse_matrix_.sparsity(); }
155
+ int block_width() const { return sparse_matrix_.block_width(); }
156
+ int block_height() const { return sparse_matrix_.block_height(); }
157
+ int num_threads() const { return sparse_matrix_.num_threads(); }
158
+ const CacheAlignedVector<BiasType>& bias() const { return bias_; }
159
+ const std::vector<int>& split_points() const {
160
+ return sparse_matrix_.split_points();
161
+ }
162
+ bool IsSplit() const {
163
+ return !thread_layers_.empty() && split_pc_ != nullptr;
164
+ }
165
+
166
+ std::size_t bytes() const { return sparse_matrix_.bytes() + bias_.bytes(); }
167
+ void Print() const {
168
+ printf("Matrix\n");
169
+ sparse_matrix_.Print();
170
+ printf("Bias\n");
171
+ bias_.Print();
172
+ }
173
+
174
+ // Combines adjacent row blocks, doubling the block height.
175
+ // This necessarily involves adding zero weights where the blocks don't align
176
+ // across adjacent pairs of rows, so use with caution, as the resulting matrix
177
+ // is most likely to run slower if very sparse to begin with.
178
+ // In the few cases where the blocks do mostly align, the resulting matmul
179
+ // could be much faster, as the number of reads of the rhs will be halved.
180
+ void DoubleBlockHeight() { sparse_matrix_.DoubleBlockHeight(); }
181
+
182
+ // Cache_line_size is provided only for testing. Normally uses a value for
183
+ // the current architecture.
184
+ int PrepareForThreads(int num_threads, int cache_line_size = -1) {
185
+ num_threads_ = num_threads;
186
+ if (num_threads_ > 1) {
187
+ split_pc_ =
188
+ absl::make_unique<ProducerConsumer>(num_threads_, num_threads_);
189
+ } else {
190
+ split_pc_.reset(nullptr);
191
+ }
192
+ return sparse_matrix_.PrepareForThreads(num_threads, cache_line_size);
193
+ }
194
+
195
+ // Partitions the matrix into pieces by thread.
196
+ // In this matrix, we can go ahead and calculate the part that only depends
197
+ // on rhs inputs that were generated by this thread in the previous matvec,
198
+ // without having to use any thread synchronization, and only after that do we
199
+ // have to wait for the other threads to finish the previous matvec.
200
+ // So we split the matrix using the |split_points| from the previous matrix
201
+ // into 2 * |num_threads_| pieces: self and other for each thread, being the
202
+ // parts that can be calculated before and after the other threads have
203
+ // completed their calculation of the previous matvec.
204
+ // We then have to use a ProducerConsumer lock instead of a SpinBarrier to
205
+ // synchronize the data produced by the other threads.
206
+ void SliceForThreads(const std::vector<int>& split_points) {
207
+ thread_layers_.clear();
208
+ thread_layers_.reserve(num_threads_);
209
+ LOG(INFO) << "Slicing " << rows() << "x" << cols() << " matrix for "
210
+ << num_threads_ << " threads";
211
+ for (int tid = 0; tid < num_threads_; ++tid) {
212
+ thread_layers_.emplace_back(
213
+ sparse_matrix_, full_bias_, bias_, tid,
214
+ split_points[tid] * sparse_matrix_.block_height(),
215
+ split_points[tid + 1] * sparse_matrix_.block_height());
216
+ }
217
+ mid_output_ =
218
+ std::move(csrblocksparse::CacheAlignedVector<BiasType>(rows()));
219
+ mid_output_.FillZero();
220
+ }
221
+
222
+ // Splits the layer by inputs into 2 equal pieces. Each of the resulting
223
+ // layers should be computed independently on the first and second halves of
224
+ // the inputs respectively and the results added to achieve the same effect
225
+ // as the original layer.
226
+ void SplitInputs(
227
+ SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part1,
228
+ SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part2) {
229
+ CsrBlockSparseMatrix<WeightType, RhsType> matrix1(
230
+ sparse_matrix_.SplitByColumn(0, sparse_matrix_.cols() / 2));
231
+ CsrBlockSparseMatrix<WeightType, RhsType> matrix2(
232
+ sparse_matrix_.SplitByColumn(sparse_matrix_.cols() / 2,
233
+ sparse_matrix_.cols()));
234
+ *part1 =
235
+ std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
236
+ std::move(matrix1),
237
+ std::move(CacheAlignedVector<BiasType>(full_bias_))));
238
+ CacheAlignedVector<BiasType> bias2(sparse_matrix_.rows());
239
+ bias2.FillZero();
240
+ *part2 =
241
+ std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
242
+ std::move(matrix2), std::move(bias2)));
243
+ }
244
+
245
+ // Splits the layer by outputs into 2 equal pieces. Each of the resulting
246
+ // layers should be computed independently on the full inputs and the results
247
+ // concatenated to achieve the same effect as the original layer.
248
+ void SplitOutputs(
249
+ SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part1,
250
+ SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part2) {
251
+ LOG(INFO) << "input rows=" << sparse_matrix_.rows()
252
+ << ", cols=" << sparse_matrix_.cols();
253
+ CsrBlockSparseMatrix<WeightType, RhsType> matrix1(
254
+ sparse_matrix_.SplitByRow(0, sparse_matrix_.rows() / 2));
255
+ CsrBlockSparseMatrix<WeightType, RhsType> matrix2(sparse_matrix_.SplitByRow(
256
+ sparse_matrix_.rows() / 2, sparse_matrix_.rows()));
257
+ CacheAlignedVector<BiasType> bias1(full_bias_, 0, full_bias_.size() / 2);
258
+ *part1 =
259
+ std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
260
+ std::move(matrix1), std::move(bias1)));
261
+ CacheAlignedVector<BiasType> bias2(full_bias_, full_bias_.size() / 2,
262
+ full_bias_.size());
263
+ *part2 =
264
+ std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
265
+ std::move(matrix2), std::move(bias2)));
266
+ }
267
+
268
+ private:
269
+ // Simple struct to hold a partitioned layer.
270
+ struct PartLinearLayer {
271
+ // The original matrix is first split by row to generate only the outputs
272
+ // for the given tid. The |row_sub_matrix| is then split by column into two
273
+ // partitions:
274
+ // self is the part for which the rhs elements in [|start_col|, |end_col|)
275
+ // were generated by this thread in some previous matmul.
276
+ // |other| is the rest of the columns that require rhs elements from other
277
+ // threads.
278
+ // NOTE that |start_col|, |end_col| are in raw columns, not blocks.
279
+ PartLinearLayer(const CsrBlockSparseMatrix<WeightType, RhsType>& matrix,
280
+ const CacheAlignedVector<BiasType>& bias,
281
+ const CacheAlignedVector<BiasType>& bias_4, int tid,
282
+ int start_col, int end_col) {
283
+ int block_height = matrix.block_height();
284
+ // Split the input matrix by row, selecting only the rows relevant to
285
+ // thread tid.
286
+ int start_row = matrix.split_points()[tid] * block_height;
287
+ int end_row = matrix.split_points()[tid + 1] * block_height;
288
+ LOG(INFO) << "input cols [" << start_col << "," << end_col << ") rows ["
289
+ << start_row << "," << end_row << ")";
290
+ CsrBlockSparseMatrix<WeightType, RhsType> row_sub_matrix =
291
+ matrix.SplitByRow(start_row, end_row);
292
+ // Partition into the columns that use rhs elements that thread tid
293
+ // produced in a previous matmul, and the other rhs elements.
294
+ // NOTE that we |keep_rhs_size|=true so that each matrix can operate on
295
+ // the same rhs input vector. The self matrix just guarantees not to
296
+ // access any of the elements that are generated by another thread.
297
+ self_matrix = std::move(row_sub_matrix.SplitByColumn(
298
+ start_col, end_col, /*keep_rhs_size=*/true));
299
+ self_matrix.PrepareForThreads(1);
300
+ // The reversed start and end slice out the complement of [start, end).
301
+ other_matrix = std::move(row_sub_matrix.SplitByColumn(
302
+ end_col, start_col, /*keep_rhs_size=*/true));
303
+ other_matrix.PrepareForThreads(1);
304
+ full_bias =
305
+ std::move(CacheAlignedVector<BiasType>(bias, start_row, end_row));
306
+ // TODO(b/189958858): Eliminate the quarter bias from all the code.
307
+ quarter_bias =
308
+ std::move(CacheAlignedVector<BiasType>(bias_4, start_row, end_row));
309
+ }
310
+ // The part of the matrix that only depends on this thread for rhs inputs.
311
+ CsrBlockSparseMatrix<WeightType, RhsType> self_matrix;
312
+ CacheAlignedVector<BiasType> full_bias;
313
+ CacheAlignedVector<BiasType> quarter_bias;
314
+ // The part of the matrix that uses rhs inputs from other threads.
315
+ CsrBlockSparseMatrix<WeightType, RhsType> other_matrix;
316
+ };
317
+ CsrBlockSparseMatrix<WeightType, RhsType, DeltaType> sparse_matrix_;
318
+ CacheAlignedVector<BiasType> bias_;
319
+ CacheAlignedVector<BiasType> full_bias_;
320
+ // Output from the self_matrix that will be given to |other_matrix| as bias.
321
+ CacheAlignedVector<BiasType> mid_output_;
322
+ // One partitioned pair of matrices for each thread.
323
+ std::vector<PartLinearLayer> thread_layers_;
324
+ // Producer-consumer lock used to wait between computing |self_matrix| and
325
+ // |other_matrix| for the other threads to finish the *previous* matvec.
326
+ std::unique_ptr<ProducerConsumer> split_pc_;
327
+ int num_threads_ = 0;
328
+ };
329
+
330
+ template <typename WeightType, typename RhsType>
331
+ SparseLinearLayer<WeightType, RhsType> CreateRandomLayer(int rows, int cols,
332
+ float sparsity,
333
+ int block_height = 1,
334
+ int block_width = 1) {
335
+ typedef typename TypeOfProduct<WeightType, RhsType>::type BiasType;
336
+ CacheAlignedVector<BiasType> bias(rows);
337
+ bias.FillRandom();
338
+
339
+ auto masked_matrix = MaskedSparseMatrix<float>(rows, cols, sparsity,
340
+ block_height, block_width);
341
+ auto sparse_matrix = CsrBlockSparseMatrix<WeightType, RhsType>(masked_matrix);
342
+
343
+ return SparseLinearLayer<WeightType, RhsType>(std::move(sparse_matrix),
344
+ std::move(bias));
345
+ }
346
+
347
+ template <typename WeightType, typename RhsType>
348
+ SparseLinearLayer<WeightType, RhsType> CreateConstantLayer(
349
+ int rows, int cols, float sparsity, float constant = 1.f) {
350
+ typedef typename TypeOfProduct<WeightType, RhsType>::type BiasType;
351
+ CacheAlignedVector<BiasType> bias(rows);
352
+ bias.FillOnes();
353
+
354
+ MaskedSparseMatrix<float> masked_matrix(rows, cols, sparsity,
355
+ /*block_height=*/1, /*block_width=*/1,
356
+ constant, /*random=*/false);
357
+ CsrBlockSparseMatrix<WeightType, RhsType> sparse_matrix(masked_matrix);
358
+
359
+ return SparseLinearLayer<WeightType, RhsType>(std::move(sparse_matrix),
360
+ std::move(bias));
361
+ }
362
+
363
+ } // namespace csrblocksparse
364
+
365
+ #endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_
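A minimal sketch of driving a SparseLinearLayer single-threaded via the CreateRandomLayer helper above; the dimensions, sparsity, and block shape are arbitrary example values.

// Sketch: single-threaded SpMM_bias on a randomly generated layer.
#include "sparse_matmul/layers/sparse_linear_layer.h"
#include "sparse_matmul/vector/cache_aligned_vector.h"

void SparseLayerExample() {
  auto layer = csrblocksparse::CreateRandomLayer<float, float>(
      /*rows=*/256, /*cols=*/256, /*sparsity=*/0.9f,
      /*block_height=*/4, /*block_width=*/4);
  layer.PrepareForThreads(1);  // the tests always prepare before computing

  csrblocksparse::FatCacheAlignedVector<float> rhs(layer.cols(), /*cols=*/1);
  csrblocksparse::FatCacheAlignedVector<float> out(layer.rows(), /*cols=*/1);
  rhs.FillRandom();
  out.FillZero();
  layer.SpMM_bias(rhs, &out, /*relu=*/true, /*tid=*/0);
}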
sparse_matmul/layers/sparse_linear_layer_test.cc ADDED
@@ -0,0 +1,187 @@
1
+ // Copyright 2021 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #include "sparse_matmul/layers/sparse_linear_layer.h"
16
+
17
+ #include "gmock/gmock.h"
18
+ #include "gtest/gtest.h"
19
+ #include "sparse_matmul/numerics/test_utils.h"
20
+
21
+ namespace csrblocksparse {
22
+ namespace {
23
+
24
+ constexpr int kBlockSize = 4;
25
+ constexpr int kSize = 256;
26
+ constexpr int kNumThreads = 4;
27
+ constexpr int kCols = 1;
28
+
29
+ void SlicedThreadBody(SpinBarrier* spin_barrier, int tid,
30
+ const FatCacheAlignedVector<float>& rhs,
31
+ SparseLinearLayer<float, float>* sparse_linear_layer,
32
+ FatCacheAlignedVector<float>* out, bool use_relu) {
33
+ sparse_linear_layer->MatVec(rhs, use_relu, tid, /*replicas=*/1,
34
+ /*output_stride=*/0, out);
35
+ spin_barrier->barrier();
36
+ }
37
+
38
+ // Tests that a Layer that has been SliceForThreads computes the same result as
39
+ // the original layer. This is a basic test that all the slicing didn't mess up
40
+ // any of the computations.
41
+ TEST(CsrBlockSparseMatrix, SliceForThreads) {
42
+ MaskedSparseMatrix<float> matrix(kSize, kSize, 0.95, kBlockSize, kBlockSize);
43
+ FatCacheAlignedVector<float> rhs(kSize, kCols);
44
+ CacheAlignedVector<float> bias(kSize);
45
+ FatCacheAlignedVector<float> out1(kSize, kCols);
46
+
47
+ bias.FillRandom();
48
+ rhs.FillRandom();
49
+ out1.FillZero();
50
+ FatCacheAlignedVector<float> out_reference = out1;
51
+ CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);
52
+ SparseLinearLayer<float, float> sparse_linear_layer(std::move(sparse_matrix),
53
+ std::move(bias));
54
+ sparse_linear_layer.PrepareForThreads(1);
55
+ sparse_linear_layer.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
56
+ /*output_stride=*/0, &out_reference);
57
+ std::vector<int> fake_split_points = {0, 48 / kBlockSize, 128 / kBlockSize,
58
+ 208 / kBlockSize, kSize / kBlockSize};
59
+ sparse_linear_layer.PrepareForThreads(kNumThreads);
60
+ sparse_linear_layer.SliceForThreads(fake_split_points);
61
+ csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, SlicedThreadBody, rhs,
62
+ &sparse_linear_layer, &out1,
63
+ /*relu=*/true);
64
+
65
+ CheckResult(out_reference, out1, kCols);
66
+ }
67
+
68
+ void LayersThreadBody(SpinBarrier* spin_barrier, int tid,
69
+ const FatCacheAlignedVector<float>& rhs,
70
+ SparseLinearLayer<float, float>* sparse_linear_layer1,
71
+ SparseLinearLayer<float, float>* sparse_linear_layer2,
72
+ FatCacheAlignedVector<float>* out1,
73
+ FatCacheAlignedVector<float>* out2, bool use_relu) {
74
+ sparse_linear_layer1->MatVec(rhs, use_relu, tid, /*replicas=*/1,
75
+ /*output_stride=*/0, out1);
76
+ // NOTE no barrier here!
77
+ sparse_linear_layer2->MatVec(*out1, use_relu, tid, /*replicas=*/1,
78
+ /*output_stride=*/0, out2);
79
+ spin_barrier->barrier();
80
+ }
81
+
82
+ // Tests that a pair of layers computes the same result whether or not the
83
+ // second layer has been SliceForThreads. This is a more critical test that
84
+ // the replacement of barriers with producer-consumer locks works.
85
+ // Must be run with tsan to really test it properly.
86
+ TEST(CsrBlockSparseMatrix, SliceForThreadsLayers) {
87
+ MaskedSparseMatrix<float> matrix1(kSize, kSize, 0.95, kBlockSize, kBlockSize);
88
+ FatCacheAlignedVector<float> rhs(kSize, kCols);
89
+ CacheAlignedVector<float> bias1(kSize);
90
+ FatCacheAlignedVector<float> out1(kSize, kCols);
91
+ MaskedSparseMatrix<float> matrix2(kSize, kSize, 0.95, kBlockSize, kBlockSize);
92
+ CacheAlignedVector<float> bias2(kSize);
93
+ FatCacheAlignedVector<float> out2(kSize, kCols);
94
+
95
+ bias1.FillRandom();
96
+ rhs.FillRandom();
97
+ bias2.FillRandom();
98
+ out1.FillZero();
99
+ out2.FillZero();
100
+ FatCacheAlignedVector<float> out_reference = out2;
101
+ CsrBlockSparseMatrix<float, float> sparse_matrix1(matrix1);
102
+ SparseLinearLayer<float, float> layer1(std::move(sparse_matrix1),
103
+ std::move(bias1));
104
+ CsrBlockSparseMatrix<float, float> sparse_matrix2(matrix2);
105
+ SparseLinearLayer<float, float> layer2(std::move(sparse_matrix2),
106
+ std::move(bias2));
107
+ layer1.PrepareForThreads(1);
108
+ layer2.PrepareForThreads(1);
109
+ layer1.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
110
+ /*output_stride=*/0, &out1);
111
+ layer2.MatVec(out1, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
112
+ /*output_stride=*/0, &out_reference);
113
+ layer1.PrepareForThreads(kNumThreads);
114
+ layer2.PrepareForThreads(kNumThreads);
115
+ layer2.SliceForThreads(layer1.split_points());
116
+ csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, LayersThreadBody, rhs,
117
+ &layer1, &layer2, &out1, &out2,
118
+ /*relu=*/true);
119
+
120
+ CheckResult(out_reference, out2, kCols);
121
+ }
122
+
123
+ // Tests that a Layer that has been DoubleBlockHeight()-ed computes the same
124
+ // result as the original layer. (Float compute type).
125
+ TEST(CsrBlockSparseMatrix, Float8x4) {
126
+ using ComputeType = float;
127
+ using RhsType = float;
128
+ using BiasType = float;
129
+ MaskedSparseMatrix<float> matrix(kSize, kSize, 0.95, kBlockSize, kBlockSize);
130
+ matrix.CastWeights<ComputeType>();
131
+ FatCacheAlignedVector<RhsType> rhs(kSize, kCols);
132
+ CacheAlignedVector<BiasType> bias(kSize);
133
+ FatCacheAlignedVector<BiasType> out1(kSize, kCols);
134
+
135
+ bias.FillRandom();
136
+ rhs.FillRandom();
137
+ out1.FillZero();
138
+ FatCacheAlignedVector<BiasType> out_reference = out1;
139
+ CsrBlockSparseMatrix<ComputeType, RhsType> sparse_matrix(matrix);
140
+ SparseLinearLayer<ComputeType, RhsType> sparse_linear_layer(
141
+ std::move(sparse_matrix), std::move(bias));
142
+ sparse_linear_layer.PrepareForThreads(1);
143
+ sparse_linear_layer.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
144
+ /*output_stride=*/0, &out_reference);
145
+ sparse_linear_layer.DoubleBlockHeight();
146
+ sparse_linear_layer.PrepareForThreads(1);
147
+ sparse_linear_layer.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
148
+ /*output_stride=*/0, &out1);
149
+ CheckResult(out_reference, out1, kCols);
150
+ }
151
+
152
+ // Tests that a Layer that has been DoubleBlockHeight()-ed computes the same
153
+ // result as the original layer. (Fixed16 compute type).
154
+ TEST(CsrBlockSparseMatrix, Fixed8x4) {
155
+ using ComputeType = csrblocksparse::fixed16<4>;
156
+ using RhsType = csrblocksparse::fixed16<4>;
157
+ using BiasType = typename TypeOfProduct<ComputeType, RhsType>::type;
158
+ MaskedSparseMatrix<float> matrix(kSize, kSize, 0.95, kBlockSize, kBlockSize);
159
+ matrix.CastWeights<ComputeType>();
160
+ FatCacheAlignedVector<RhsType> rhs(kSize, kCols);
161
+ CacheAlignedVector<BiasType> bias(kSize);
162
+ FatCacheAlignedVector<BiasType> out1(kSize, kCols);
163
+
164
+ bias.FillRandom();
165
+ rhs.FillRandom();
166
+ out1.FillZero();
167
+ FatCacheAlignedVector<BiasType> out_reference = out1;
168
+ CsrBlockSparseMatrix<ComputeType, RhsType> sparse_matrix(matrix);
169
+ SparseLinearLayer<ComputeType, RhsType> sparse_linear_layer(
170
+ std::move(sparse_matrix), std::move(bias));
171
+ sparse_linear_layer.PrepareForThreads(1);
172
+ sparse_linear_layer.MatVec(rhs, /*relu=*/false, /*tid=*/0, /*replicas=*/1,
173
+ /*output_stride=*/0, &out_reference);
174
+ sparse_linear_layer.DoubleBlockHeight();
175
+ sparse_linear_layer.PrepareForThreads(1);
176
+ sparse_linear_layer.MatVec(rhs, /*relu=*/false, /*tid=*/0, /*replicas=*/1,
177
+ /*output_stride=*/0, &out1);
178
+ CheckResult(out_reference, out1, kCols);
179
+ }
180
+
181
+ TEST(SparseLinearLayerTest, PrintCompiles) {
182
+ SparseLinearLayer<float, float> sparse_linear_layer;
183
+ sparse_linear_layer.Print();
184
+ }
185
+
186
+ } // namespace
187
+ } // namespace csrblocksparse
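The tests above exercise the SparseLinearLayer workflow end to end: build a MaskedSparseMatrix, convert it to a CsrBlockSparseMatrix, wrap it in a SparseLinearLayer together with a bias vector, call PrepareForThreads(), then MatVec(). A minimal single-threaded sketch of that workflow, mirroring the Float8x4 test, is shown below; the include paths, the helper name RunSparseLayerOnce, and the dimensions are illustrative assumptions rather than part of this upload.
// Sketch only: single-threaded use of SparseLinearLayer, mirroring the
// Float8x4 test above. Include paths and sizes are assumptions.
#include <utility>

#include "sparse_matmul/layers/masked_sparse_matrix.h"   // assumed path
#include "sparse_matmul/layers/sparse_linear_layer.h"    // assumed path
#include "sparse_matmul/vector/cache_aligned_vector.h"   // assumed path

void RunSparseLayerOnce() {
  constexpr int kRows = 256;   // illustrative layer size
  constexpr int kCols = 1;     // number of right-hand-side columns
  constexpr int kBlock = 4;    // 4x4 block-sparse structure
  // 95% sparse random weight matrix with 4x4 blocks, as in the tests above.
  csrblocksparse::MaskedSparseMatrix<float> masked(kRows, kRows, 0.95, kBlock,
                                                   kBlock);
  csrblocksparse::FatCacheAlignedVector<float> rhs(kRows, kCols);
  csrblocksparse::CacheAlignedVector<float> bias(kRows);
  csrblocksparse::FatCacheAlignedVector<float> out(kRows, kCols);
  bias.FillRandom();
  rhs.FillRandom();
  out.FillZero();
  // Convert to the CSR block-sparse representation and wrap it in a layer.
  csrblocksparse::CsrBlockSparseMatrix<float, float> sparse(masked);
  csrblocksparse::SparseLinearLayer<float, float> layer(std::move(sparse),
                                                        std::move(bias));
  layer.PrepareForThreads(1);
  // Computes out = relu(weights * rhs + bias) on one thread, one replica.
  layer.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
               /*output_stride=*/0, &out);
}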
sparse_matmul/layers/status_macros.h ADDED
@@ -0,0 +1,34 @@
1
+ // Copyright 2021 Google LLC
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License");
4
+ // you may not use this file except in compliance with the License.
5
+ // You may obtain a copy of the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS,
11
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ // See the License for the specific language governing permissions and
13
+ // limitations under the License.
14
+
15
+ #ifndef THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_STATUS_MACROS_H_
16
+ #define THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_STATUS_MACROS_H_
17
+
18
+ #include "absl/status/status.h"
19
+ #include "absl/status/statusor.h"
20
+
21
+ #define SPARSE_MATMUL_RETURN_IF_ERROR(expr) \
22
+ do { \
23
+ const absl::Status _status = (expr); \
24
+ if (!_status.ok()) return _status; \
25
+ } while (0)
26
+ template <typename T>
27
+ absl::Status DoAssignOrReturn(T& lhs, absl::StatusOr<T> result) {
28
+ if (result.ok()) {
29
+ lhs = result.value();
30
+ }
31
+ return result.status();
32
+ }
33
+
34
+ #endif // THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_STATUS_MACROS_H_
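SPARSE_MATMUL_RETURN_IF_ERROR propagates any non-OK absl::Status to the caller, and DoAssignOrReturn copies the value out of an absl::StatusOr<T> into an existing variable while surfacing its status. A hypothetical usage sketch follows, assuming this header is included; LoadBias() and LoadWeightCount() are illustrative names, not functions from this upload.
// Sketch only: typical use of the status helpers above.
absl::Status LoadBias();                // hypothetical dependency
absl::StatusOr<int> LoadWeightCount();  // hypothetical dependency

absl::Status LoadAll() {
  // Returns early with the error if LoadBias() fails.
  SPARSE_MATMUL_RETURN_IF_ERROR(LoadBias());
  int weight_count = 0;
  // Assigns on success; otherwise returns the non-OK status.
  SPARSE_MATMUL_RETURN_IF_ERROR(
      DoAssignOrReturn(weight_count, LoadWeightCount()));
  return absl::OkStatus();
}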
sparse_matmul/layers/testdata/768_512_95_4x4_QRhat_weights.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50f861af29b1f767830d74ef83874944b18d80157b6b0256fdc4c14fa79ec936
3
+ size 20852
sparse_matmul/layers/testdata/768_512_95_4x4_What_weights.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2d534bde2caf6e59990a46b4b1907088b8144c53d62d97de7e2b4bdc956da68
3
+ size 5133
sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_bias.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11399f9d0e8f8dfbef6eb37e0c096f858658bc650f728a08f3135ccca44f0a5a
3
+ size 1062
sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_mask.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3d971e067a6df985d68beac26bcf4e9a6cc13ff328599e84d50a0fc9a7c103b
3
+ size 2382
sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_weights.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d1376ef7a360699dae24a49f40a254990d4a70b844dadcdbe9dcbf1a306999a8
3
+ size 55829
sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_bias.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ffcc8ccf086fccfacc928877aa29ef03ce51cce0f0b7d2aacf81782b7b527089
3
+ size 2003
sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_mask.raw.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a16f98ba6f09031ea9fefb79fdc9ba90e44f0046ab70dab014ac971ca7f7186
3
+ size 4684