Spaces:
Runtime error
Runtime error
mrsu0994
commited on
Commit
·
154f182
1
Parent(s):
3bf15e2
upload f5-tts source
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +4 -0
- BUILD +44 -0
- Dockerfile +34 -0
- WORKSPACE +154 -0
- alphabet.txt +97 -0
- app.py +145 -0
- bazelisk-linux-amd64 +3 -0
- build_ext.sh +3 -0
- extract_tacotrons_model.py +8 -0
- extract_wavegru_model.py +12 -0
- inference.py +90 -0
- mono_tts_cbhg_small_0700000.ckpt +3 -0
- packages.txt +7 -0
- pooch.py +10 -0
- requirements.txt +13 -0
- sparse_matmul/BUILD +22 -0
- sparse_matmul/compute/BUILD +88 -0
- sparse_matmul/compute/ar_inputs.h +37 -0
- sparse_matmul/compute/gru_gates.h +214 -0
- sparse_matmul/compute/gru_gates_arm.h +288 -0
- sparse_matmul/compute/gru_gates_avx_fixed.h +348 -0
- sparse_matmul/compute/gru_gates_generic.h +97 -0
- sparse_matmul/compute/gru_gates_test.cc +164 -0
- sparse_matmul/compute/kernels_arm.h +0 -0
- sparse_matmul/compute/kernels_avx.h +601 -0
- sparse_matmul/compute/kernels_generic.h +273 -0
- sparse_matmul/compute/matmul.h +199 -0
- sparse_matmul/compute/matmul_fixed_avx2.cc +235 -0
- sparse_matmul/compute/matmul_fixed_avx2.h +49 -0
- sparse_matmul/compute/matmul_generic.cc +122 -0
- sparse_matmul/compute/matmul_generic.h +41 -0
- sparse_matmul/compute/thread_bounds.cc +106 -0
- sparse_matmul/compute/thread_bounds.h +74 -0
- sparse_matmul/layers/BUILD +146 -0
- sparse_matmul/layers/csr_blocksparse_matrix.h +835 -0
- sparse_matmul/layers/csrblocksparse_test.cc +977 -0
- sparse_matmul/layers/errno_mapping.cc +195 -0
- sparse_matmul/layers/errno_mapping.h +29 -0
- sparse_matmul/layers/masked_sparse_matrix.h +206 -0
- sparse_matmul/layers/read_array_ifstream.h +66 -0
- sparse_matmul/layers/sparse_linear_layer.h +365 -0
- sparse_matmul/layers/sparse_linear_layer_test.cc +187 -0
- sparse_matmul/layers/status_macros.h +34 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_QRhat_weights.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_What_weights.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_bias.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_mask.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_weights.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_bias.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_mask.raw.gz +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
bazelisk-linux-amd64 filter=lfs diff=lfs merge=lfs -text
|
BUILD
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [internal] load cc_fuzz_target.bzl
|
| 2 |
+
# [internal] load cc_proto_library.bzl
|
| 3 |
+
# [internal] load android_cc_test:def.bzl
|
| 4 |
+
|
| 5 |
+
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
|
| 6 |
+
|
| 7 |
+
package(default_visibility = [":__subpackages__"])
|
| 8 |
+
|
| 9 |
+
licenses(["notice"])
|
| 10 |
+
|
| 11 |
+
# To run all cc_tests in this directory:
|
| 12 |
+
# bazel test //:all
|
| 13 |
+
|
| 14 |
+
# [internal] Command to run dsp_util_android_test.
|
| 15 |
+
|
| 16 |
+
# [internal] Command to run lyra_integration_android_test.
|
| 17 |
+
|
| 18 |
+
exports_files(
|
| 19 |
+
srcs = [
|
| 20 |
+
"wavegru_mod.cc",
|
| 21 |
+
],
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
pybind_extension(
|
| 25 |
+
name = "wavegru_mod", # This name is not actually created!
|
| 26 |
+
srcs = ["wavegru_mod.cc"],
|
| 27 |
+
deps = [
|
| 28 |
+
"//sparse_matmul",
|
| 29 |
+
],
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
py_library(
|
| 33 |
+
name = "wavegru_mod",
|
| 34 |
+
data = [":wavegru_mod.so"],
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
py_binary(
|
| 38 |
+
name = "wavegru",
|
| 39 |
+
srcs = ["wavegru.py"],
|
| 40 |
+
deps = [
|
| 41 |
+
":wavegru_mod"
|
| 42 |
+
],
|
| 43 |
+
)
|
| 44 |
+
|
Dockerfile
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
|
| 2 |
+
# you will also find guides on how best to write your Dockerfile
|
| 3 |
+
|
| 4 |
+
FROM us-docker.pkg.dev/colab-images/public/runtime:latest
|
| 5 |
+
|
| 6 |
+
RUN apt update; apt install libsndfile1-dev make autoconf automake libtool gcc pkg-config -y python3-dev
|
| 7 |
+
|
| 8 |
+
WORKDIR /code
|
| 9 |
+
|
| 10 |
+
COPY ./requirements.txt /code/requirements.txt
|
| 11 |
+
|
| 12 |
+
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
| 13 |
+
|
| 14 |
+
# Set up a new user named "user" with user ID 1000
|
| 15 |
+
RUN useradd -m -u 1000 user
|
| 16 |
+
|
| 17 |
+
# Switch to the "user" user
|
| 18 |
+
USER user
|
| 19 |
+
|
| 20 |
+
# Set home to the user's home directory
|
| 21 |
+
ENV HOME=/home/user \
|
| 22 |
+
PATH=/home/user/.local/bin:$PATH
|
| 23 |
+
|
| 24 |
+
# Set the working directory to the user's home directory
|
| 25 |
+
WORKDIR $HOME/app
|
| 26 |
+
|
| 27 |
+
# Copy the current directory contents into the container at $HOME/app setting the owner to the user
|
| 28 |
+
COPY --chown=user . $HOME/app
|
| 29 |
+
|
| 30 |
+
RUN bash ./build_ext.sh
|
| 31 |
+
EXPOSE 7860
|
| 32 |
+
ENV GRADIO_SERVER_NAME="0.0.0.0"
|
| 33 |
+
|
| 34 |
+
ENTRYPOINT ["python", "app.py"]
|
WORKSPACE
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
########################
|
| 2 |
+
# Platform Independent #
|
| 3 |
+
########################
|
| 4 |
+
|
| 5 |
+
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository", "new_git_repository")
|
| 6 |
+
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
| 7 |
+
|
| 8 |
+
# GoogleTest/GoogleMock framework.
|
| 9 |
+
git_repository(
|
| 10 |
+
name = "com_google_googletest",
|
| 11 |
+
remote = "https://github.com/google/googletest.git",
|
| 12 |
+
tag = "release-1.10.0",
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
# Google benchmark.
|
| 16 |
+
http_archive(
|
| 17 |
+
name = "com_github_google_benchmark",
|
| 18 |
+
urls = ["https://github.com/google/benchmark/archive/bf585a2789e30585b4e3ce6baf11ef2750b54677.zip"], # 2020-11-26T11:14:03Z
|
| 19 |
+
strip_prefix = "benchmark-bf585a2789e30585b4e3ce6baf11ef2750b54677",
|
| 20 |
+
sha256 = "2a778d821997df7d8646c9c59b8edb9a573a6e04c534c01892a40aa524a7b68c",
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
# proto_library, cc_proto_library, and java_proto_library rules implicitly
|
| 24 |
+
# depend on @com_google_protobuf for protoc and proto runtimes.
|
| 25 |
+
# This statement defines the @com_google_protobuf repo.
|
| 26 |
+
git_repository(
|
| 27 |
+
name = "com_google_protobuf",
|
| 28 |
+
remote = "https://github.com/protocolbuffers/protobuf.git",
|
| 29 |
+
tag = "v3.15.4",
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps")
|
| 33 |
+
protobuf_deps()
|
| 34 |
+
|
| 35 |
+
# Google Abseil Libs
|
| 36 |
+
git_repository(
|
| 37 |
+
name = "com_google_absl",
|
| 38 |
+
remote = "https://github.com/abseil/abseil-cpp.git",
|
| 39 |
+
branch = "lts_2020_09_23",
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Filesystem
|
| 43 |
+
# The new_* prefix is used because it is not a bazel project and there is
|
| 44 |
+
# no BUILD file in that repo.
|
| 45 |
+
FILESYSTEM_BUILD = """
|
| 46 |
+
cc_library(
|
| 47 |
+
name = "filesystem",
|
| 48 |
+
hdrs = glob(["include/ghc/*"]),
|
| 49 |
+
visibility = ["//visibility:public"],
|
| 50 |
+
)
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
new_git_repository(
|
| 54 |
+
name = "gulrak_filesystem",
|
| 55 |
+
remote = "https://github.com/gulrak/filesystem.git",
|
| 56 |
+
tag = "v1.3.6",
|
| 57 |
+
build_file_content = FILESYSTEM_BUILD
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Audio DSP
|
| 61 |
+
git_repository(
|
| 62 |
+
name = "com_google_audio_dsp",
|
| 63 |
+
remote = "https://github.com/google/multichannel-audio-tools.git",
|
| 64 |
+
# There are no tags for this repo, we are synced to bleeding edge.
|
| 65 |
+
branch = "master",
|
| 66 |
+
repo_mapping = {
|
| 67 |
+
"@com_github_glog_glog" : "@com_google_glog"
|
| 68 |
+
}
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
http_archive(
|
| 73 |
+
name = "pybind11_bazel",
|
| 74 |
+
strip_prefix = "pybind11_bazel-72cbbf1fbc830e487e3012862b7b720001b70672",
|
| 75 |
+
urls = ["https://github.com/pybind/pybind11_bazel/archive/72cbbf1fbc830e487e3012862b7b720001b70672.zip"],
|
| 76 |
+
)
|
| 77 |
+
# We still require the pybind library.
|
| 78 |
+
http_archive(
|
| 79 |
+
name = "pybind11",
|
| 80 |
+
build_file = "@pybind11_bazel//:pybind11.BUILD",
|
| 81 |
+
strip_prefix = "pybind11-2.10.0",
|
| 82 |
+
urls = ["https://github.com/pybind/pybind11/archive/v2.10.0.tar.gz"],
|
| 83 |
+
)
|
| 84 |
+
load("@pybind11_bazel//:python_configure.bzl", "python_configure")
|
| 85 |
+
python_configure(name = "local_config_python")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# Transitive dependencies of Audio DSP.
|
| 90 |
+
http_archive(
|
| 91 |
+
name = "eigen_archive",
|
| 92 |
+
build_file = "eigen.BUILD",
|
| 93 |
+
sha256 = "f3d69ac773ecaf3602cb940040390d4e71a501bb145ca9e01ce5464cf6d4eb68",
|
| 94 |
+
strip_prefix = "eigen-eigen-049af2f56331",
|
| 95 |
+
urls = [
|
| 96 |
+
"http://mirror.tensorflow.org/bitbucket.org/eigen/eigen/get/049af2f56331.tar.gz",
|
| 97 |
+
"https://bitbucket.org/eigen/eigen/get/049af2f56331.tar.gz",
|
| 98 |
+
],
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
http_archive(
|
| 102 |
+
name = "fft2d",
|
| 103 |
+
build_file = "fft2d.BUILD",
|
| 104 |
+
sha256 = "ada7e99087c4ed477bfdf11413f2ba8db8a840ba9bbf8ac94f4f3972e2a7cec9",
|
| 105 |
+
urls = [
|
| 106 |
+
"http://www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz",
|
| 107 |
+
],
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
# Google logging
|
| 111 |
+
git_repository(
|
| 112 |
+
name = "com_google_glog",
|
| 113 |
+
remote = "https://github.com/google/glog.git",
|
| 114 |
+
tag = "v0.5.0"
|
| 115 |
+
)
|
| 116 |
+
# Dependency for glog
|
| 117 |
+
git_repository(
|
| 118 |
+
name = "com_github_gflags_gflags",
|
| 119 |
+
remote = "https://github.com/mchinen/gflags.git",
|
| 120 |
+
branch = "android_linking_fix"
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# Bazel/build rules
|
| 124 |
+
|
| 125 |
+
http_archive(
|
| 126 |
+
name = "bazel_skylib",
|
| 127 |
+
urls = [
|
| 128 |
+
"https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
|
| 129 |
+
"https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
|
| 130 |
+
],
|
| 131 |
+
sha256 = "97e70364e9249702246c0e9444bccdc4b847bed1eb03c5a3ece4f83dfe6abc44",
|
| 132 |
+
)
|
| 133 |
+
load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace")
|
| 134 |
+
bazel_skylib_workspace()
|
| 135 |
+
|
| 136 |
+
http_archive(
|
| 137 |
+
name = "rules_android",
|
| 138 |
+
sha256 = "cd06d15dd8bb59926e4d65f9003bfc20f9da4b2519985c27e190cddc8b7a7806",
|
| 139 |
+
strip_prefix = "rules_android-0.1.1",
|
| 140 |
+
urls = ["https://github.com/bazelbuild/rules_android/archive/v0.1.1.zip"],
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# Google Maven Repository
|
| 144 |
+
GMAVEN_TAG = "20180625-1"
|
| 145 |
+
|
| 146 |
+
http_archive(
|
| 147 |
+
name = "gmaven_rules",
|
| 148 |
+
strip_prefix = "gmaven_rules-%s" % GMAVEN_TAG,
|
| 149 |
+
url = "https://github.com/bazelbuild/gmaven_rules/archive/%s.tar.gz" % GMAVEN_TAG,
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
load("@gmaven_rules//:gmaven.bzl", "gmaven_rules")
|
| 153 |
+
|
| 154 |
+
gmaven_rules()
|
alphabet.txt
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_
|
| 2 |
+
■
|
| 3 |
+
|
| 4 |
+
!
|
| 5 |
+
,
|
| 6 |
+
.
|
| 7 |
+
:
|
| 8 |
+
?
|
| 9 |
+
a
|
| 10 |
+
b
|
| 11 |
+
c
|
| 12 |
+
d
|
| 13 |
+
e
|
| 14 |
+
g
|
| 15 |
+
h
|
| 16 |
+
i
|
| 17 |
+
k
|
| 18 |
+
l
|
| 19 |
+
m
|
| 20 |
+
n
|
| 21 |
+
o
|
| 22 |
+
p
|
| 23 |
+
q
|
| 24 |
+
r
|
| 25 |
+
s
|
| 26 |
+
t
|
| 27 |
+
u
|
| 28 |
+
v
|
| 29 |
+
x
|
| 30 |
+
y
|
| 31 |
+
à
|
| 32 |
+
á
|
| 33 |
+
â
|
| 34 |
+
ã
|
| 35 |
+
è
|
| 36 |
+
é
|
| 37 |
+
ê
|
| 38 |
+
ì
|
| 39 |
+
í
|
| 40 |
+
ò
|
| 41 |
+
ó
|
| 42 |
+
ô
|
| 43 |
+
õ
|
| 44 |
+
ù
|
| 45 |
+
ú
|
| 46 |
+
ý
|
| 47 |
+
ă
|
| 48 |
+
đ
|
| 49 |
+
ĩ
|
| 50 |
+
ũ
|
| 51 |
+
ơ
|
| 52 |
+
ư
|
| 53 |
+
ạ
|
| 54 |
+
ả
|
| 55 |
+
ấ
|
| 56 |
+
ầ
|
| 57 |
+
ẩ
|
| 58 |
+
ẫ
|
| 59 |
+
ậ
|
| 60 |
+
ắ
|
| 61 |
+
ằ
|
| 62 |
+
ẳ
|
| 63 |
+
ẵ
|
| 64 |
+
ặ
|
| 65 |
+
ẹ
|
| 66 |
+
ẻ
|
| 67 |
+
ẽ
|
| 68 |
+
ế
|
| 69 |
+
ề
|
| 70 |
+
ể
|
| 71 |
+
ễ
|
| 72 |
+
ệ
|
| 73 |
+
ỉ
|
| 74 |
+
ị
|
| 75 |
+
ọ
|
| 76 |
+
ỏ
|
| 77 |
+
ố
|
| 78 |
+
ồ
|
| 79 |
+
ổ
|
| 80 |
+
ỗ
|
| 81 |
+
ộ
|
| 82 |
+
ớ
|
| 83 |
+
ờ
|
| 84 |
+
ở
|
| 85 |
+
ỡ
|
| 86 |
+
ợ
|
| 87 |
+
ụ
|
| 88 |
+
ủ
|
| 89 |
+
ứ
|
| 90 |
+
ừ
|
| 91 |
+
ử
|
| 92 |
+
ữ
|
| 93 |
+
ự
|
| 94 |
+
ỳ
|
| 95 |
+
ỵ
|
| 96 |
+
ỷ
|
| 97 |
+
ỹ
|
app.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## build wavegru-cpp
|
| 2 |
+
# import os
|
| 3 |
+
# os.system("./bazelisk-linux-amd64 clean --expunge")
|
| 4 |
+
# os.system("./bazelisk-linux-amd64 build wavegru_mod -c opt --copt=-march=native")
|
| 5 |
+
|
| 6 |
+
# install espeak
|
| 7 |
+
import os
|
| 8 |
+
import re
|
| 9 |
+
import unicodedata
|
| 10 |
+
|
| 11 |
+
import regex
|
| 12 |
+
|
| 13 |
+
if not os.path.isfile("./wavegru_mod.so"):
|
| 14 |
+
os.system("bash ./build_ext.sh")
|
| 15 |
+
|
| 16 |
+
import gradio as gr
|
| 17 |
+
|
| 18 |
+
from inference import load_tacotron_model, load_wavegru_net, mel_to_wav, text_to_mel
|
| 19 |
+
from wavegru_cpp import extract_weight_mask, load_wavegru_cpp
|
| 20 |
+
|
| 21 |
+
alphabet, tacotron_net, tacotron_config = load_tacotron_model(
|
| 22 |
+
"./alphabet.txt", "./tacotron.toml", "./mono_tts_cbhg_small_0700000.ckpt"
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
wavegru_config, wavegru_net = load_wavegru_net(
|
| 26 |
+
"./wavegru.yaml", "./wavegru_vocoder_tpu_gta_preemphasis_pruning_0400000.ckpt"
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
wave_cpp_weight_mask = extract_weight_mask(wavegru_net)
|
| 30 |
+
wavecpp = load_wavegru_cpp(wave_cpp_weight_mask, wavegru_config["upsample_factors"][-1])
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
space_re = regex.compile(r"\s+")
|
| 34 |
+
number_re = regex.compile("([0-9]+)")
|
| 35 |
+
digits = ["không", "một", "hai", "ba", "bốn", "năm", "sáu", "bảy", "tám", "chín"]
|
| 36 |
+
num_re = regex.compile(r"([0-9.,]*[0-9])")
|
| 37 |
+
alphabet_ = "aàáảãạăằắẳẵặâầấẩẫậeèéẻẽẹêềếểễệiìíỉĩịoòóỏõọôồốổỗộơờớởỡợuùúủũụưừứửữựyỳýỷỹỵbcdđghklmnpqrstvx"
|
| 38 |
+
keep_text_and_num_re = regex.compile(rf"[^\s{alphabet_}.,0-9]")
|
| 39 |
+
keep_text_re = regex.compile(rf"[^\s{alphabet_}]")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def read_number(num: str) -> str:
|
| 43 |
+
if len(num) == 1:
|
| 44 |
+
return digits[int(num)]
|
| 45 |
+
elif len(num) == 2 and num.isdigit():
|
| 46 |
+
n = int(num)
|
| 47 |
+
end = digits[n % 10]
|
| 48 |
+
if n == 10:
|
| 49 |
+
return "mười"
|
| 50 |
+
if n % 10 == 5:
|
| 51 |
+
end = "lăm"
|
| 52 |
+
if n % 10 == 0:
|
| 53 |
+
return digits[n // 10] + " mươi"
|
| 54 |
+
elif n < 20:
|
| 55 |
+
return "mười " + end
|
| 56 |
+
else:
|
| 57 |
+
if n % 10 == 1:
|
| 58 |
+
end = "mốt"
|
| 59 |
+
return digits[n // 10] + " mươi " + end
|
| 60 |
+
elif len(num) == 3 and num.isdigit():
|
| 61 |
+
n = int(num)
|
| 62 |
+
if n % 100 == 0:
|
| 63 |
+
return digits[n // 100] + " trăm"
|
| 64 |
+
elif num[1] == "0":
|
| 65 |
+
return digits[n // 100] + " trăm lẻ " + digits[n % 100]
|
| 66 |
+
else:
|
| 67 |
+
return digits[n // 100] + " trăm " + read_number(num[1:])
|
| 68 |
+
elif len(num) >= 4 and len(num) <= 6 and num.isdigit():
|
| 69 |
+
n = int(num)
|
| 70 |
+
n1 = n // 1000
|
| 71 |
+
return read_number(str(n1)) + " ngàn " + read_number(num[-3:])
|
| 72 |
+
elif "," in num:
|
| 73 |
+
n1, n2 = num.split(",")
|
| 74 |
+
return read_number(n1) + " phẩy " + read_number(n2)
|
| 75 |
+
elif "." in num:
|
| 76 |
+
parts = num.split(".")
|
| 77 |
+
if len(parts) == 2:
|
| 78 |
+
if parts[1] == "000":
|
| 79 |
+
return read_number(parts[0]) + " ngàn"
|
| 80 |
+
elif parts[1].startswith("00"):
|
| 81 |
+
end = digits[int(parts[1][2:])]
|
| 82 |
+
return read_number(parts[0]) + " ngàn lẻ " + end
|
| 83 |
+
else:
|
| 84 |
+
return read_number(parts[0]) + " ngàn " + read_number(parts[1])
|
| 85 |
+
elif len(parts) == 3:
|
| 86 |
+
return (
|
| 87 |
+
read_number(parts[0])
|
| 88 |
+
+ " triệu "
|
| 89 |
+
+ read_number(parts[1])
|
| 90 |
+
+ " ngàn "
|
| 91 |
+
+ read_number(parts[2])
|
| 92 |
+
)
|
| 93 |
+
return num
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def normalize_text(text):
|
| 97 |
+
# lowercase
|
| 98 |
+
text = text.lower()
|
| 99 |
+
# unicode normalize
|
| 100 |
+
text = unicodedata.normalize("NFKC", text)
|
| 101 |
+
text = text.replace(".", ". ")
|
| 102 |
+
text = text.replace(",", ", ")
|
| 103 |
+
text = text.replace(";", "; ")
|
| 104 |
+
text = text.replace(":", ": ")
|
| 105 |
+
text = text.replace("!", "! ")
|
| 106 |
+
text = text.replace("?", "? ")
|
| 107 |
+
text = text.replace("(", "( ")
|
| 108 |
+
|
| 109 |
+
text = num_re.sub(r" \1 ", text)
|
| 110 |
+
words = text.split()
|
| 111 |
+
words = [read_number(w) if num_re.fullmatch(w) else w for w in words]
|
| 112 |
+
text = " ".join(words)
|
| 113 |
+
|
| 114 |
+
# remove redundant spaces
|
| 115 |
+
text = re.sub(r"\s+", " ", text)
|
| 116 |
+
# remove leading and trailing spaces
|
| 117 |
+
text = text.strip()
|
| 118 |
+
return text
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def speak(text):
|
| 122 |
+
text = normalize_text(text)
|
| 123 |
+
mel = text_to_mel(tacotron_net, text, alphabet, tacotron_config)
|
| 124 |
+
y = mel_to_wav(wavegru_net, wavecpp, mel, wavegru_config)
|
| 125 |
+
return 24_000, y
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
title = "WaveGRU-TTS"
|
| 129 |
+
description = "WaveGRU text-to-speech demo."
|
| 130 |
+
|
| 131 |
+
gr.Interface(
|
| 132 |
+
fn=speak,
|
| 133 |
+
inputs="text",
|
| 134 |
+
examples=[
|
| 135 |
+
"Trăm năm trong cõi người ta, chữ tài chữ mệnh khéo là ghét nhau.",
|
| 136 |
+
"Đoạn trường tân thanh, thường được biết đến với cái tên đơn giản là Truyện Kiều, là một truyện thơ của đại thi hào Nguyễn Du.",
|
| 137 |
+
"Lục Vân Tiên quê ở huyện Đông Thành, khôi ngô tuấn tú, tài kiêm văn võ. Nghe tin triều đình mở khoa thi, Vân Tiên từ giã thầy xuống núi đua tài.",
|
| 138 |
+
"Lê Quý Đôn, tên thuở nhỏ là Lê Danh Phương, là vị quan thời Lê trung hưng, cũng là nhà thơ và được mệnh danh là nhà bác học lớn c���a Việt Nam trong thời phong kiến.",
|
| 139 |
+
"Tất cả mọi người đều sinh ra có quyền bình đẳng. Tạo hóa cho họ những quyền không ai có thể xâm phạm được, trong những quyền ấy, có quyền được sống, quyền tự do và quyền mưu cầu hạnh phúc.",
|
| 140 |
+
],
|
| 141 |
+
outputs="audio",
|
| 142 |
+
title=title,
|
| 143 |
+
description=description,
|
| 144 |
+
theme="default",
|
| 145 |
+
).launch()
|
bazelisk-linux-amd64
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:231ec5ca8115e94c75a1f4fbada1a062b48822ca04f21f26e4cb1cd8973cd458
|
| 3 |
+
size 5152768
|
build_ext.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
chmod +x ./bazelisk-linux-amd64
|
| 2 |
+
USE_BAZEL_VERSION=5.0.0 ./bazelisk-linux-amd64 build wavegru_mod -c opt --copt=-march=native
|
| 3 |
+
cp -f bazel-bin/wavegru_mod.so .
|
extract_tacotrons_model.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
|
| 3 |
+
import jax
|
| 4 |
+
|
| 5 |
+
dic = pickle.load(open("./mono_tts_cbhg_small_0700000.ckpt", "rb"))
|
| 6 |
+
del dic["optim_state_dict"]
|
| 7 |
+
dic = jax.device_get(dic)
|
| 8 |
+
pickle.dump(dic, open("./mono_tts_cbhg_small_0700000.ckpt", "wb"))
|
extract_wavegru_model.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
|
| 3 |
+
import jax
|
| 4 |
+
|
| 5 |
+
dic = pickle.load(
|
| 6 |
+
open("./wavegru_vocoder_tpu_gta_preemphasis_pruning_0800000.ckpt", "rb")
|
| 7 |
+
)
|
| 8 |
+
dic = jax.device_get(dic)
|
| 9 |
+
del dic["optim_state_dict"]
|
| 10 |
+
pickle.dump(
|
| 11 |
+
dic, open("./wavegru_vocoder_tpu_gta_preemphasis_pruning_0800000.ckpt", "wb")
|
| 12 |
+
)
|
inference.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import jax
|
| 4 |
+
import jax.numpy as jnp
|
| 5 |
+
import librosa
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pax
|
| 8 |
+
|
| 9 |
+
# from text import english_cleaners
|
| 10 |
+
from utils import (
|
| 11 |
+
create_tacotron_model,
|
| 12 |
+
load_tacotron_ckpt,
|
| 13 |
+
load_tacotron_config,
|
| 14 |
+
load_wavegru_ckpt,
|
| 15 |
+
load_wavegru_config,
|
| 16 |
+
)
|
| 17 |
+
from wavegru import WaveGRU
|
| 18 |
+
|
| 19 |
+
# os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = "./espeak/usr/lib/libespeak-ng.so.1.1.51"
|
| 20 |
+
# from phonemizer.backend import EspeakBackend
|
| 21 |
+
# backend = EspeakBackend("en-us", preserve_punctuation=True, with_stress=True)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def load_tacotron_model(alphabet_file, config_file, model_file):
|
| 25 |
+
"""load tacotron model to memory"""
|
| 26 |
+
with open(alphabet_file, "r", encoding="utf-8") as f:
|
| 27 |
+
alphabet = f.read().split("\n")
|
| 28 |
+
|
| 29 |
+
config = load_tacotron_config(config_file)
|
| 30 |
+
net = create_tacotron_model(config)
|
| 31 |
+
_, net, _ = load_tacotron_ckpt(net, None, model_file)
|
| 32 |
+
net = net.eval()
|
| 33 |
+
net = jax.device_put(net)
|
| 34 |
+
return alphabet, net, config
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
tacotron_inference_fn = pax.pure(lambda net, text: net.inference(text, max_len=2400))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def text_to_mel(net, text, alphabet, config):
|
| 41 |
+
"""convert text to mel spectrogram"""
|
| 42 |
+
# text = english_cleaners(text)
|
| 43 |
+
# text = backend.phonemize([text], strip=True)[0]
|
| 44 |
+
text = text + config["END_CHARACTER"]
|
| 45 |
+
text = text + config["PAD"] * (100 - (len(text) % 100))
|
| 46 |
+
tokens = []
|
| 47 |
+
for c in text:
|
| 48 |
+
if c in alphabet:
|
| 49 |
+
tokens.append(alphabet.index(c))
|
| 50 |
+
tokens = jnp.array(tokens, dtype=jnp.int32)
|
| 51 |
+
mel = tacotron_inference_fn(net, tokens[None])
|
| 52 |
+
return mel
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def load_wavegru_net(config_file, model_file):
|
| 56 |
+
"""load wavegru to memory"""
|
| 57 |
+
config = load_wavegru_config(config_file)
|
| 58 |
+
net = WaveGRU(
|
| 59 |
+
mel_dim=config["mel_dim"],
|
| 60 |
+
rnn_dim=config["rnn_dim"],
|
| 61 |
+
upsample_factors=config["upsample_factors"],
|
| 62 |
+
has_linear_output=True,
|
| 63 |
+
)
|
| 64 |
+
_, net, _ = load_wavegru_ckpt(net, None, model_file)
|
| 65 |
+
net = net.eval()
|
| 66 |
+
net = jax.device_put(net)
|
| 67 |
+
return config, net
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
wavegru_inference = pax.pure(lambda net, mel: net.inference(mel, no_gru=True))
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def mel_to_wav(net, netcpp, mel, config):
|
| 74 |
+
"""convert mel to wav"""
|
| 75 |
+
if len(mel.shape) == 2:
|
| 76 |
+
mel = mel[None]
|
| 77 |
+
pad = config["num_pad_frames"] // 2 + 2
|
| 78 |
+
mel = np.pad(mel, [(0, 0), (pad, pad), (0, 0)], mode="edge")
|
| 79 |
+
ft = wavegru_inference(net, mel)
|
| 80 |
+
ft = jax.device_get(ft[0])
|
| 81 |
+
wav = netcpp.inference(ft, 1.0)
|
| 82 |
+
wav = np.array(wav)
|
| 83 |
+
wav = librosa.mu_expand(wav - 127, mu=255)
|
| 84 |
+
wav = librosa.effects.deemphasis(wav, coef=0.86)
|
| 85 |
+
wav = wav * 2.0
|
| 86 |
+
wav = wav / max(1.0, np.max(np.abs(wav)))
|
| 87 |
+
wav = wav * 2**15
|
| 88 |
+
wav = np.clip(wav, a_min=-(2**15), a_max=(2**15) - 1)
|
| 89 |
+
wav = wav.astype(np.int16)
|
| 90 |
+
return wav
|
mono_tts_cbhg_small_0700000.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:94a3cf9879f6c71ed21a6569f6f8167a8f4990e46b036b5f8196a16ea14fcb7e
|
| 3 |
+
size 53480857
|
packages.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
libsndfile1-dev
|
| 2 |
+
make
|
| 3 |
+
autoconf
|
| 4 |
+
automake
|
| 5 |
+
libtool
|
| 6 |
+
gcc
|
| 7 |
+
pkg-config
|
pooch.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def os_cache(x):
|
| 2 |
+
return x
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def create(*args, **kwargs):
|
| 6 |
+
class T:
|
| 7 |
+
def load_registry(self, *args, **kwargs):
|
| 8 |
+
return None
|
| 9 |
+
|
| 10 |
+
return T()
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
inflect
|
| 2 |
+
jax
|
| 3 |
+
jaxlib
|
| 4 |
+
jinja2
|
| 5 |
+
librosa==0.9.0
|
| 6 |
+
numpy
|
| 7 |
+
pax3 @ git+https://github.com/ntt123/pax.git
|
| 8 |
+
pyyaml
|
| 9 |
+
toml
|
| 10 |
+
unidecode
|
| 11 |
+
phonemizer
|
| 12 |
+
gradio
|
| 13 |
+
setuptools
|
sparse_matmul/BUILD
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [internal] load placeholder
|
| 2 |
+
|
| 3 |
+
licenses(["notice"])
|
| 4 |
+
|
| 5 |
+
cc_library(
|
| 6 |
+
name = "sparse_matmul",
|
| 7 |
+
hdrs = [
|
| 8 |
+
"sparse_matmul.h",
|
| 9 |
+
],
|
| 10 |
+
visibility = ["//visibility:public"],
|
| 11 |
+
deps = [
|
| 12 |
+
"//sparse_matmul/compute:gru_gates",
|
| 13 |
+
"//sparse_matmul/layers:layer",
|
| 14 |
+
"//sparse_matmul/layers:matrix",
|
| 15 |
+
"//sparse_matmul/layers:utils",
|
| 16 |
+
"//sparse_matmul/numerics:fast_transcendentals",
|
| 17 |
+
"//sparse_matmul/numerics:types",
|
| 18 |
+
"//sparse_matmul/os:coop_threads",
|
| 19 |
+
"//sparse_matmul/vector:cache_aligned_vector",
|
| 20 |
+
], # internal :sparse_matmul deps placeholder
|
| 21 |
+
)
|
| 22 |
+
|
sparse_matmul/compute/BUILD
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Low-level computation code, including generic and architecture-specific
|
| 2 |
+
# variants.
|
| 3 |
+
|
| 4 |
+
licenses(["notice"])
|
| 5 |
+
|
| 6 |
+
cc_library(
|
| 7 |
+
name = "gru_gates",
|
| 8 |
+
srcs = [
|
| 9 |
+
"ar_inputs.h",
|
| 10 |
+
"gru_gates_arm.h",
|
| 11 |
+
"gru_gates_avx_fixed.h",
|
| 12 |
+
"gru_gates_generic.h",
|
| 13 |
+
],
|
| 14 |
+
hdrs = ["gru_gates.h"],
|
| 15 |
+
visibility = [
|
| 16 |
+
"//visibility:public",
|
| 17 |
+
],
|
| 18 |
+
deps = [
|
| 19 |
+
":matmul",
|
| 20 |
+
"//sparse_matmul/numerics:fast_transcendentals",
|
| 21 |
+
"//sparse_matmul/numerics:types",
|
| 22 |
+
"//sparse_matmul/vector:cache_aligned_vector",
|
| 23 |
+
],
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
cc_library(
|
| 27 |
+
name = "kernels",
|
| 28 |
+
srcs = [
|
| 29 |
+
"kernels_arm.h",
|
| 30 |
+
"kernels_avx.h",
|
| 31 |
+
],
|
| 32 |
+
hdrs = [
|
| 33 |
+
"kernels_generic.h",
|
| 34 |
+
],
|
| 35 |
+
visibility = [
|
| 36 |
+
"//sparse_matmul:__subpackages__",
|
| 37 |
+
],
|
| 38 |
+
deps = [
|
| 39 |
+
"//sparse_matmul/numerics:fast_transcendentals",
|
| 40 |
+
"//sparse_matmul/numerics:types",
|
| 41 |
+
],
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
cc_library(
|
| 45 |
+
name = "matmul",
|
| 46 |
+
srcs = [
|
| 47 |
+
"matmul_fixed_avx2.cc",
|
| 48 |
+
"matmul_fixed_avx2.h",
|
| 49 |
+
"matmul_generic.cc",
|
| 50 |
+
"matmul_generic.h",
|
| 51 |
+
],
|
| 52 |
+
hdrs = [
|
| 53 |
+
"matmul.h",
|
| 54 |
+
],
|
| 55 |
+
visibility = [
|
| 56 |
+
"//sparse_matmul:__subpackages__",
|
| 57 |
+
],
|
| 58 |
+
deps = [
|
| 59 |
+
"//sparse_matmul/numerics:types",
|
| 60 |
+
"@com_google_absl//absl/time",
|
| 61 |
+
],
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
cc_library(
|
| 65 |
+
name = "thread_bounds",
|
| 66 |
+
srcs = ["thread_bounds.cc"],
|
| 67 |
+
hdrs = ["thread_bounds.h"],
|
| 68 |
+
visibility = [
|
| 69 |
+
"//sparse_matmul:__subpackages__",
|
| 70 |
+
],
|
| 71 |
+
deps = [
|
| 72 |
+
"@com_google_glog//:glog",
|
| 73 |
+
],
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
cc_test(
|
| 77 |
+
name = "gru_gates_test",
|
| 78 |
+
size = "small",
|
| 79 |
+
srcs = [
|
| 80 |
+
"gru_gates_test.cc",
|
| 81 |
+
],
|
| 82 |
+
deps = [
|
| 83 |
+
":gru_gates",
|
| 84 |
+
"@com_google_absl//absl/memory",
|
| 85 |
+
"@com_google_absl//absl/types:span",
|
| 86 |
+
"@com_google_googletest//:gtest_main",
|
| 87 |
+
],
|
| 88 |
+
)
|
sparse_matmul/compute/ar_inputs.h
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2021 Google LLC
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_AR_INPUTS_H_
|
| 18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_AR_INPUTS_H_
|
| 19 |
+
|
| 20 |
+
namespace csrblocksparse {
|
| 21 |
+
|
| 22 |
+
// Possible numbers of Autoregressive inputs.
|
| 23 |
+
// TODO(b/188702959): Generalize to any non-negative integer value?
|
| 24 |
+
enum class ARInputsMode {
|
| 25 |
+
// There are no autoregressive inputs. Inputs to the GRU gates are strictly
|
| 26 |
+
// from the gate-recurrent matmul and other unrelated inputs.
|
| 27 |
+
k0ARInputs,
|
| 28 |
+
// Two autoregressive inputs, such as coarse and fine for WaveRNN.
|
| 29 |
+
k2ARInputs,
|
| 30 |
+
// Three autoregressive inputs, such as prev coarse and fine plus current
|
| 31 |
+
// coarse for WaveRNN.
|
| 32 |
+
k3ARInputs,
|
| 33 |
+
};
|
| 34 |
+
|
| 35 |
+
} // namespace csrblocksparse
|
| 36 |
+
|
| 37 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_AR_INPUTS_H_
|
sparse_matmul/compute/gru_gates.h
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2021 Google LLC
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_
|
| 18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_
|
| 19 |
+
|
| 20 |
+
#include <cstdint>
|
| 21 |
+
#include <vector>
|
| 22 |
+
|
| 23 |
+
// IWYU pragma: begin_exports
|
| 24 |
+
#include "sparse_matmul/compute/ar_inputs.h"
|
| 25 |
+
#include "sparse_matmul/compute/gru_gates_arm.h"
|
| 26 |
+
#include "sparse_matmul/compute/gru_gates_avx_fixed.h"
|
| 27 |
+
#include "sparse_matmul/compute/gru_gates_generic.h"
|
| 28 |
+
#include "sparse_matmul/compute/matmul.h"
|
| 29 |
+
#include "sparse_matmul/numerics/fixed_types.h"
|
| 30 |
+
#include "sparse_matmul/numerics/type_utils.h"
|
| 31 |
+
#include "sparse_matmul/vector/cache_aligned_vector.h"
|
| 32 |
+
// IWYU pragma: end_exports
|
| 33 |
+
|
| 34 |
+
namespace csrblocksparse {
|
| 35 |
+
|
| 36 |
+
// The master template is really a catch-all for the unimplemented cases to
|
| 37 |
+
// run the generics.
|
| 38 |
+
template <typename GRUStateType, typename InputType, typename SampleType = void>
|
| 39 |
+
class GruGates : public MatmulBase {
|
| 40 |
+
public:
|
| 41 |
+
using SampleWeightType = float;
|
| 42 |
+
static constexpr int kSIMDWidth = kGenericSIMDWidth;
|
| 43 |
+
|
| 44 |
+
// Generic GRU function covers all uses for WaveRNN-like architectures and
|
| 45 |
+
// conditioning.
|
| 46 |
+
// Controlled by template parameters thus:
|
| 47 |
+
// - |kInputsMode| == |k0ARInputs|: There are no autoregressive inputs so
|
| 48 |
+
// |ar_sample0|, |ar_sample1|, |ar_sample2|, |ar_01_weights|,
|
| 49 |
+
// |ar_2_weights| are ignored.
|
| 50 |
+
// - |kInputsMode| == |k2ARInputs|: |ar_sample0|, |ar_sample1| are multiplied
|
| 51 |
+
// by |ar_01_weights| and added to the (conditioning) input.
|
| 52 |
+
// - |kInputsMode| == |k3ARInputs|: |ar_sample2| is multiplied by
|
| 53 |
+
// |ar_2_weights| and added to the other two |ar_inputs| (and added to the
|
| 54 |
+
// conditioning input).
|
| 55 |
+
// - If |kSplitGates| is true: The |*gru_recurrent_other_ptr| is secondary
|
| 56 |
+
// recurrent input that must be added to |*gru_recurrent_ptr|.
|
| 57 |
+
// - |num_replicas| determines the number of duplicates of the output to be
|
| 58 |
+
// written, separated by |replica_stride|.
|
| 59 |
+
// - |start|, |end| are |rows| in [0, |state_size|] to be processed by this
|
| 60 |
+
// thread.
|
| 61 |
+
//
|
| 62 |
+
// Previous state is read from |*gru_state_ptr| and the new state is written
|
| 63 |
+
// to *(|gru_state_ptr| + i * |replica_stride| for i in [0, |num_replicas|)).
|
| 64 |
+
template <ARInputsMode kInputsMode = ARInputsMode::k2ARInputs,
|
| 65 |
+
bool kSplitGates = false>
|
| 66 |
+
void GruWithARInput(int start, int end, int state_size,
|
| 67 |
+
const InputType* gru_recurrent_ptr,
|
| 68 |
+
const InputType* input_ptr, GRUStateType* gru_state_ptr,
|
| 69 |
+
const SampleType* ar_sample0 = nullptr,
|
| 70 |
+
const SampleType* ar_sample1 = nullptr,
|
| 71 |
+
const SampleWeightType* ar_01_weights = nullptr,
|
| 72 |
+
int num_replicas = 1, int replica_stride = 0,
|
| 73 |
+
const SampleType* ar_sample2 = nullptr,
|
| 74 |
+
const SampleWeightType* ar_2_weights = nullptr,
|
| 75 |
+
const InputType* gru_recurrent_other_ptr = nullptr) {
|
| 76 |
+
CHECK_EQ(num_replicas, 1) << "Generic code should always have 1 replica";
|
| 77 |
+
GoThroughGates<GRUStateType, InputType, SampleWeightType, SampleType,
|
| 78 |
+
kInputsMode, kSplitGates>(
|
| 79 |
+
start, end, ar_01_weights, gru_recurrent_ptr, gru_recurrent_other_ptr,
|
| 80 |
+
input_ptr, gru_state_ptr, ar_2_weights, state_size, ar_sample0,
|
| 81 |
+
ar_sample1, ar_sample2);
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
// No AR inputs, no split gates, no batching, no replicated outputs.
|
| 85 |
+
// TODO(b/188702959): Redirect conditioning GRU here, removing code from
|
| 86 |
+
// gru_layer.h.
|
| 87 |
+
// Copy to specializations.
|
| 88 |
+
void PlainGru(int start, int end, int state_size,
|
| 89 |
+
const InputType* gru_recurrent_ptr, const InputType* input_ptr,
|
| 90 |
+
GRUStateType* gru_state_ptr) {
|
| 91 |
+
GruWithARInput<ARInputsMode::k0ARInputs>(
|
| 92 |
+
start, end, state_size, gru_recurrent_ptr, input_ptr, gru_state_ptr);
|
| 93 |
+
}
|
| 94 |
+
};
|
| 95 |
+
|
| 96 |
+
#if defined __ARM_NEON || defined __aarch64__
|
| 97 |
+
// Partial specialization for float.
|
| 98 |
+
template <>
|
| 99 |
+
class GruGates<float, float, float> : public MatmulBase {
|
| 100 |
+
public:
|
| 101 |
+
static constexpr int kSIMDWidth = kNeonSIMDWidth;
|
| 102 |
+
|
| 103 |
+
// Generic GRU function covers all uses for WaveRNN-like architectures and
|
| 104 |
+
// conditioning.
|
| 105 |
+
template <ARInputsMode kInputsMode = ARInputsMode::k2ARInputs,
|
| 106 |
+
bool kSplitGates = false>
|
| 107 |
+
void GruWithARInput(int start, int end, int state_size,
|
| 108 |
+
const float* gru_recurrent_data, const float* input_data,
|
| 109 |
+
float* gru_state_data, const float* ar_sample0 = nullptr,
|
| 110 |
+
const float* ar_sample1 = nullptr,
|
| 111 |
+
const float* ar_01_weights = nullptr,
|
| 112 |
+
int num_replicas = 1, int replica_stride = 0,
|
| 113 |
+
const float* ar_sample2 = nullptr,
|
| 114 |
+
const float* ar_2_weights = nullptr,
|
| 115 |
+
const float* gru_recurrent_other_data = nullptr) {
|
| 116 |
+
DCHECK_EQ(num_replicas, 1) << "ARM code should always have 1 replica";
|
| 117 |
+
GoThroughGatesFloat<kInputsMode, kSplitGates>(
|
| 118 |
+
start, end, ar_01_weights, gru_recurrent_data, gru_recurrent_other_data,
|
| 119 |
+
input_data, gru_state_data, ar_2_weights, state_size, ar_sample0,
|
| 120 |
+
ar_sample1, ar_sample2);
|
| 121 |
+
}
|
| 122 |
+
};
|
| 123 |
+
#endif // defined __ARM_NEON || defined __aarch64__
|
| 124 |
+
|
| 125 |
+
// Partial specialization for fixed types. The sample weights are always float
|
| 126 |
+
// whatever the fixed type of the other weights.
|
| 127 |
+
template <int kGRUStateBits, int kInputBits, int kSampleBits>
|
| 128 |
+
class GruGates<fixed16<kGRUStateBits>, fixed32<kInputBits>,
|
| 129 |
+
fixed16<kSampleBits>> : public MatmulBase {
|
| 130 |
+
public:
|
| 131 |
+
#if defined __ARM_NEON || defined __aarch64__
|
| 132 |
+
static constexpr int kSIMDWidth = kNeonSIMDWidth;
|
| 133 |
+
#elif defined __AVX2__
|
| 134 |
+
static constexpr int kSIMDWidth = kAVX2SIMDWidth * 2;
|
| 135 |
+
#else // Generic case.
|
| 136 |
+
static constexpr int kSIMDWidth = kGenericSIMDWidth;
|
| 137 |
+
#endif // __ARM_NEON || defined __aarch64__ / __AVX2__
|
| 138 |
+
|
| 139 |
+
using GRUStateType = fixed16<kGRUStateBits>;
|
| 140 |
+
using InputType = fixed32<kInputBits>;
|
| 141 |
+
using SampleType = fixed16<kSampleBits>;
|
| 142 |
+
using SampleWeightType = float;
|
| 143 |
+
static constexpr int kInputMantissaBits = InputType::kMantissaBits;
|
| 144 |
+
static constexpr int kSampleMantissaBits = SampleType::kMantissaBits;
|
| 145 |
+
static constexpr int kStateMantissaBits = GRUStateType::kMantissaBits;
|
| 146 |
+
// Generic GRU function covers all uses for WaveRNN-like architectures and
|
| 147 |
+
// conditioning.
|
| 148 |
+
template <ARInputsMode kInputsMode = ARInputsMode::k2ARInputs,
|
| 149 |
+
bool kSplitGates = false>
|
| 150 |
+
void GruWithARInput(int start, int end, int state_size,
|
| 151 |
+
const InputType* gru_recurrent_data,
|
| 152 |
+
const InputType* input_data, GRUStateType* gru_state_data,
|
| 153 |
+
const SampleType* ar_sample0 = nullptr,
|
| 154 |
+
const SampleType* ar_sample1 = nullptr,
|
| 155 |
+
const SampleWeightType* ar_01_weights = nullptr,
|
| 156 |
+
int num_replicas = 1, int replica_stride = 0,
|
| 157 |
+
const SampleType* ar_sample2 = nullptr,
|
| 158 |
+
const SampleWeightType* ar_2_weights = nullptr,
|
| 159 |
+
const InputType* gru_recurrent_other_data = nullptr) {
|
| 160 |
+
#if defined __ARM_NEON || defined __aarch64__ || defined __AVX2__
|
| 161 |
+
const int32_t* gru_recurrent_ptr =
|
| 162 |
+
reinterpret_cast<const int32_t*>(gru_recurrent_data);
|
| 163 |
+
const int32_t* gru_recurrent_other_ptr =
|
| 164 |
+
reinterpret_cast<const int32_t*>(gru_recurrent_other_data);
|
| 165 |
+
const int32_t* input_ptr = reinterpret_cast<const int32_t*>(input_data);
|
| 166 |
+
int16_t* gru_state_ptr = reinterpret_cast<int16_t*>(gru_state_data);
|
| 167 |
+
#if defined __AVX2__
|
| 168 |
+
// The samples are fixed16, but we scale them up here and convert to float
|
| 169 |
+
// so that the product with the QR weights is always on the same scale as
|
| 170 |
+
// InputType, so we don't have to do any more scaling inside.
|
| 171 |
+
const float sample_factor = static_cast<float>(1 << kInputMantissaBits);
|
| 172 |
+
#else
|
| 173 |
+
const float sample_factor = 1.0f;
|
| 174 |
+
#endif
|
| 175 |
+
// AR sample 0 and 1 are packed into a pair because the QR weights are
|
| 176 |
+
// formatted with the weights interleaved for sample 0 and 1.
|
| 177 |
+
std::pair<float, float> ar_sample01;
|
| 178 |
+
float ar_sample2_float = 0.0f;
|
| 179 |
+
if (kInputsMode == ARInputsMode::k2ARInputs ||
|
| 180 |
+
kInputsMode == ARInputsMode::k3ARInputs) {
|
| 181 |
+
ar_sample01 = {static_cast<float>(*ar_sample0) * sample_factor,
|
| 182 |
+
static_cast<float>(*ar_sample1) * sample_factor};
|
| 183 |
+
if (kInputsMode == ARInputsMode::k3ARInputs) {
|
| 184 |
+
ar_sample2_float = static_cast<float>(*ar_sample2) * sample_factor;
|
| 185 |
+
}
|
| 186 |
+
}
|
| 187 |
+
#if defined __AVX2__
|
| 188 |
+
CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!";
|
| 189 |
+
GruGatesAVXFixed<kInputMantissaBits, kStateMantissaBits, kInputsMode,
|
| 190 |
+
kSplitGates>(
|
| 191 |
+
start, end, state_size, gru_recurrent_ptr, input_ptr, &ar_sample01,
|
| 192 |
+
ar_01_weights, num_replicas, replica_stride, &ar_sample2_float,
|
| 193 |
+
ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
|
| 194 |
+
#else // ARM.
|
| 195 |
+
DCHECK_EQ(num_replicas, 1) << "ARM code should always have 1 replica";
|
| 196 |
+
GoThroughGatesFixed<GRUStateType, InputType, kInputsMode, kSplitGates>(
|
| 197 |
+
start, end, ar_01_weights, gru_recurrent_ptr, gru_recurrent_other_ptr,
|
| 198 |
+
input_ptr, gru_state_ptr, ar_2_weights, state_size, &ar_sample01,
|
| 199 |
+
&ar_sample2_float);
|
| 200 |
+
#endif // __AVX2__ / ARM.
|
| 201 |
+
#else // Generic case.
|
| 202 |
+
CHECK_EQ(num_replicas, 1) << "Generic code should always have 1 replica";
|
| 203 |
+
GoThroughGates<GRUStateType, InputType, SampleWeightType, SampleType,
|
| 204 |
+
kInputsMode, kSplitGates>(
|
| 205 |
+
start, end, ar_01_weights, gru_recurrent_data, gru_recurrent_other_data,
|
| 206 |
+
input_data, gru_state_data, ar_2_weights, state_size, ar_sample0,
|
| 207 |
+
ar_sample1, ar_sample2);
|
| 208 |
+
#endif // __ARM_NEON || defined __aarch64__ / __AVX2__
|
| 209 |
+
}
|
| 210 |
+
};
|
| 211 |
+
|
| 212 |
+
} // namespace csrblocksparse
|
| 213 |
+
|
| 214 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_
|
sparse_matmul/compute/gru_gates_arm.h
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2021 Google LLC
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_
|
| 18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_
|
| 19 |
+
|
| 20 |
+
#if defined __ARM_NEON || defined __aarch64__
|
| 21 |
+
#include <arm_neon.h>
|
| 22 |
+
#endif
|
| 23 |
+
#include <cstdint>
|
| 24 |
+
|
| 25 |
+
#include "sparse_matmul/compute/ar_inputs.h"
|
| 26 |
+
#include "sparse_matmul/numerics/fast_transcendentals.h"
|
| 27 |
+
|
| 28 |
+
namespace csrblocksparse {
|
| 29 |
+
|
| 30 |
+
static constexpr int kNeonSIMDWidth = 4;
|
| 31 |
+
|
| 32 |
+
// ------ Scalar calculation --------
|
| 33 |
+
// See "Efficient Neural Audio Synthesis" for a description of the calculation.
|
| 34 |
+
// https://arxiv.org/abs/1802.08435
|
| 35 |
+
//
|
| 36 |
+
// NOTE:
|
| 37 |
+
// |sample| = (|coarse_at_sminus1|, |fine_at_sminus1|,
|
| 38 |
+
// |coarse_at_sminus1|, |fine_at_sminus1|)
|
| 39 |
+
// |w_sample| = (|coarse_at_s|, |coarse_at_s|, |coarse_at_s|, |coarse_at_s|)
|
| 40 |
+
//
|
| 41 |
+
// CHEATSHEET:
|
| 42 |
+
// vld1q_f32 = load 4 32-bit floats
|
| 43 |
+
// vmulq_f32(a, b) : return a * b;
|
| 44 |
+
// vaddq_f32(a, b) : return a + b;
|
| 45 |
+
// vmlaq_f32(c, a, b) : return c + a * b;
|
| 46 |
+
// vpaddq_f32(a, b) : return (a0 + a1, a2 + a3, b0 + b1, b2 + b3)
|
| 47 |
+
// vsubq_f32(a, b) : return a - b;
|
| 48 |
+
// vst1q_f32 = store 4 32-bit floats
|
| 49 |
+
#if defined __ARM_NEON || defined __aarch64__
|
| 50 |
+
|
| 51 |
+
#if !defined __aarch64__
|
| 52 |
+
// Backport of vpaddq_f32 to ARM32.
|
| 53 |
+
inline float32x4_t vpaddq_f32(float32x4_t a, float32x4_t b) {
|
| 54 |
+
float32x2_t a10 = vget_low_f32(a);
|
| 55 |
+
float32x2_t a32 = vget_high_f32(a);
|
| 56 |
+
float32x2_t b10 = vget_low_f32(b);
|
| 57 |
+
float32x2_t b32 = vget_high_f32(b);
|
| 58 |
+
return vcombine_f32(vpadd_f32(a10, a32), vpadd_f32(b10, b32));
|
| 59 |
+
}
|
| 60 |
+
#endif
|
| 61 |
+
|
| 62 |
+
template <ARInputsMode kInputsMode, bool SplitGates>
|
| 63 |
+
void GoThroughGatesFloat(int start, int end, const float* qr_ptr,
|
| 64 |
+
const float* gru_gates_ptr,
|
| 65 |
+
const float* gru_gates_other_ptr,
|
| 66 |
+
const float* conditioning_ptr, float* gru_h_ptr,
|
| 67 |
+
const float* w_hat, int proj_size,
|
| 68 |
+
const float* coarse_at_sminus1,
|
| 69 |
+
const float* fine_at_sminus1,
|
| 70 |
+
const float* coarse_at_s) {
|
| 71 |
+
// Increment all the pointers to save on pointer arithmetic in the loop.
|
| 72 |
+
conditioning_ptr += start;
|
| 73 |
+
gru_h_ptr += start;
|
| 74 |
+
gru_gates_ptr += start;
|
| 75 |
+
if (SplitGates) {
|
| 76 |
+
DCHECK_NE(gru_gates_other_ptr, nullptr);
|
| 77 |
+
gru_gates_other_ptr += start;
|
| 78 |
+
}
|
| 79 |
+
if (kInputsMode != ARInputsMode::k0ARInputs) {
|
| 80 |
+
DCHECK_NE(qr_ptr, nullptr);
|
| 81 |
+
qr_ptr += 2 * start;
|
| 82 |
+
DCHECK_NE(coarse_at_sminus1, nullptr);
|
| 83 |
+
DCHECK_NE(fine_at_sminus1, nullptr);
|
| 84 |
+
if (kInputsMode == ARInputsMode::k3ARInputs) {
|
| 85 |
+
DCHECK_NE(w_hat, nullptr);
|
| 86 |
+
DCHECK_NE(coarse_at_s, nullptr);
|
| 87 |
+
w_hat += start;
|
| 88 |
+
}
|
| 89 |
+
}
|
| 90 |
+
for (int i = start; i < end; i += kNeonSIMDWidth) {
|
| 91 |
+
float32x4_t reset = vld1q_f32(gru_gates_ptr);
|
| 92 |
+
float32x4_t update = vld1q_f32(gru_gates_ptr + proj_size);
|
| 93 |
+
float32x4_t cell = vld1q_f32(gru_gates_ptr + 2 * proj_size);
|
| 94 |
+
float32x4_t qr_cell;
|
| 95 |
+
if (SplitGates) {
|
| 96 |
+
reset = vaddq_f32(reset, vld1q_f32(gru_gates_other_ptr));
|
| 97 |
+
update = vaddq_f32(update, vld1q_f32(gru_gates_other_ptr + proj_size));
|
| 98 |
+
cell = vaddq_f32(cell, vld1q_f32(gru_gates_other_ptr + 2 * proj_size));
|
| 99 |
+
}
|
| 100 |
+
if (kInputsMode != ARInputsMode::k0ARInputs) {
|
| 101 |
+
// Setup the sample vector.
|
| 102 |
+
float32x4_t sample = vdupq_n_f32(*coarse_at_sminus1);
|
| 103 |
+
sample = vsetq_lane_f32(*fine_at_sminus1, sample, 1);
|
| 104 |
+
sample = vsetq_lane_f32(*fine_at_sminus1, sample, 3);
|
| 105 |
+
|
| 106 |
+
// All auto types are float32x4_t, auto used to fit statements on one line
|
| 107 |
+
// for readability. Do two rows of QR at once.
|
| 108 |
+
auto qr_reset_0 = vmulq_f32(vld1q_f32(qr_ptr), sample);
|
| 109 |
+
auto qr_reset_1 = vmulq_f32(vld1q_f32(qr_ptr + 4), sample);
|
| 110 |
+
auto qr_reset = vpaddq_f32(qr_reset_0, qr_reset_1);
|
| 111 |
+
|
| 112 |
+
auto qr_update_0 = vmulq_f32(vld1q_f32(qr_ptr + 2 * proj_size), sample);
|
| 113 |
+
auto qr_update_1 =
|
| 114 |
+
vmulq_f32(vld1q_f32(qr_ptr + 4 + 2 * proj_size), sample);
|
| 115 |
+
auto qr_update = vpaddq_f32(qr_update_0, qr_update_1);
|
| 116 |
+
|
| 117 |
+
auto qr_cell_0 = vmulq_f32(vld1q_f32(qr_ptr + 4 * proj_size), sample);
|
| 118 |
+
auto qr_cell_1 = vmulq_f32(vld1q_f32(qr_ptr + 4 + 4 * proj_size), sample);
|
| 119 |
+
qr_cell = vpaddq_f32(qr_cell_0, qr_cell_1);
|
| 120 |
+
|
| 121 |
+
if (kInputsMode == ARInputsMode::k3ARInputs) {
|
| 122 |
+
float32x4_t w_sample = vdupq_n_f32(*coarse_at_s);
|
| 123 |
+
qr_reset = vmlaq_f32(qr_reset, vld1q_f32(w_hat), w_sample);
|
| 124 |
+
qr_update =
|
| 125 |
+
vmlaq_f32(qr_update, vld1q_f32(w_hat + proj_size), w_sample);
|
| 126 |
+
qr_cell =
|
| 127 |
+
vmlaq_f32(qr_cell, vld1q_f32(w_hat + 2 * proj_size), w_sample);
|
| 128 |
+
}
|
| 129 |
+
reset = vaddq_f32(reset, qr_reset);
|
| 130 |
+
update = vaddq_f32(update, qr_update);
|
| 131 |
+
}
|
| 132 |
+
auto reset_conditioning = vld1q_f32(conditioning_ptr);
|
| 133 |
+
auto update_conditioning = vld1q_f32(conditioning_ptr + proj_size);
|
| 134 |
+
auto cell_conditioning = vld1q_f32(conditioning_ptr + 2 * proj_size);
|
| 135 |
+
|
| 136 |
+
reset = fast_sigmoid(vaddq_f32(reset, reset_conditioning));
|
| 137 |
+
update = fast_sigmoid(vaddq_f32(update, update_conditioning));
|
| 138 |
+
if (kInputsMode == ARInputsMode::k0ARInputs) {
|
| 139 |
+
cell = vmulq_f32(reset, cell);
|
| 140 |
+
} else {
|
| 141 |
+
cell = vmlaq_f32(qr_cell, reset, cell);
|
| 142 |
+
}
|
| 143 |
+
auto hbar = fast_tanh(vaddq_f32(cell, cell_conditioning));
|
| 144 |
+
|
| 145 |
+
auto prev_h = vld1q_f32(gru_h_ptr);
|
| 146 |
+
auto diff = vsubq_f32(prev_h, hbar);
|
| 147 |
+
auto new_h = vmlaq_f32(hbar, diff, update);
|
| 148 |
+
|
| 149 |
+
vst1q_f32(gru_h_ptr, new_h);
|
| 150 |
+
// Increment all the pointers.
|
| 151 |
+
conditioning_ptr += kNeonSIMDWidth;
|
| 152 |
+
gru_h_ptr += kNeonSIMDWidth;
|
| 153 |
+
gru_gates_ptr += kNeonSIMDWidth;
|
| 154 |
+
if (SplitGates) gru_gates_other_ptr += kNeonSIMDWidth;
|
| 155 |
+
if (kInputsMode != ARInputsMode::k0ARInputs) {
|
| 156 |
+
qr_ptr += 2 * kNeonSIMDWidth;
|
| 157 |
+
if (kInputsMode == ARInputsMode::k3ARInputs) w_hat += kNeonSIMDWidth;
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
// This version should only be used if all of the 32-bit fixed point
|
| 163 |
+
// representations have the same number of mantissa bits.
|
| 164 |
+
// |ar_at_sminus1| packs sample 0 and 1 into a pair because the QR weights are
|
| 165 |
+
// formatted with the weights interleaved for sample 0 and 1. The two samples
|
| 166 |
+
// represent coarse and fine for WaveRNN.
|
| 167 |
+
template <typename GRUStateType, typename GRUMatMulOutType,
|
| 168 |
+
ARInputsMode kInputsMode, bool SplitGates>
|
| 169 |
+
void GoThroughGatesFixed(int start, int end, const float* qr_ptr,
|
| 170 |
+
const int32_t* gru_gates_ptr,
|
| 171 |
+
const int32_t* gru_gates_other_ptr,
|
| 172 |
+
const int32_t* conditioning_ptr, int16_t* gru_h_ptr,
|
| 173 |
+
const float* w_hat, int proj_size,
|
| 174 |
+
const std::pair<float, float>* ar_at_sminus1,
|
| 175 |
+
const float* coarse_at_s) {
|
| 176 |
+
// Increment all the pointers to save on pointer arithmetic in the loop.
|
| 177 |
+
conditioning_ptr += start;
|
| 178 |
+
gru_h_ptr += start;
|
| 179 |
+
gru_gates_ptr += start;
|
| 180 |
+
if (SplitGates) {
|
| 181 |
+
DCHECK_NE(gru_gates_other_ptr, nullptr);
|
| 182 |
+
gru_gates_other_ptr += start;
|
| 183 |
+
}
|
| 184 |
+
float32x4_t sample01;
|
| 185 |
+
float32x4_t w_sample;
|
| 186 |
+
if (kInputsMode != ARInputsMode::k0ARInputs) {
|
| 187 |
+
DCHECK_NE(qr_ptr, nullptr);
|
| 188 |
+
qr_ptr += 2 * start;
|
| 189 |
+
DCHECK_NE(ar_at_sminus1, nullptr);
|
| 190 |
+
sample01 = vdupq_n_f32(ar_at_sminus1->first);
|
| 191 |
+
sample01 = vsetq_lane_f32(ar_at_sminus1->second, sample01, 1);
|
| 192 |
+
sample01 = vsetq_lane_f32(ar_at_sminus1->second, sample01, 3);
|
| 193 |
+
if (kInputsMode == ARInputsMode::k3ARInputs) {
|
| 194 |
+
DCHECK_NE(w_hat, nullptr);
|
| 195 |
+
DCHECK_NE(coarse_at_s, nullptr);
|
| 196 |
+
w_hat += start;
|
| 197 |
+
w_sample = vdupq_n_f32(*coarse_at_s);
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
for (int i = start; i < end; i += kNeonSIMDWidth) {
|
| 201 |
+
auto reset = vld1q_s32(gru_gates_ptr);
|
| 202 |
+
auto update = vld1q_s32(gru_gates_ptr + proj_size);
|
| 203 |
+
// vcvtq_n_f32_s32 = convert 32-bit fixed point to fp32
|
| 204 |
+
auto cell_int = vld1q_s32(gru_gates_ptr + 2 * proj_size);
|
| 205 |
+
if (SplitGates) {
|
| 206 |
+
reset = vaddq_s32(reset, vld1q_s32(gru_gates_other_ptr));
|
| 207 |
+
update = vaddq_s32(update, vld1q_s32(gru_gates_other_ptr + proj_size));
|
| 208 |
+
cell_int =
|
| 209 |
+
vaddq_s32(cell_int, vld1q_s32(gru_gates_other_ptr + 2 * proj_size));
|
| 210 |
+
}
|
| 211 |
+
float32x4_t cell =
|
| 212 |
+
vcvtq_n_f32_s32(cell_int, GRUMatMulOutType::kMantissaBits);
|
| 213 |
+
float32x4_t qr_cell;
|
| 214 |
+
if (kInputsMode != ARInputsMode::k0ARInputs) {
|
| 215 |
+
// Do two rows of QR at once.
|
| 216 |
+
float32x4_t qr_reset_0 = vmulq_f32(vld1q_f32(qr_ptr), sample01);
|
| 217 |
+
float32x4_t qr_reset_1 = vmulq_f32(vld1q_f32(qr_ptr + 4), sample01);
|
| 218 |
+
float32x4_t qr_reset = vpaddq_f32(qr_reset_0, qr_reset_1);
|
| 219 |
+
|
| 220 |
+
float32x4_t qr_update_0 =
|
| 221 |
+
vmulq_f32(vld1q_f32(qr_ptr + 2 * proj_size), sample01);
|
| 222 |
+
float32x4_t qr_update_1 =
|
| 223 |
+
vmulq_f32(vld1q_f32(qr_ptr + 4 + 2 * proj_size), sample01);
|
| 224 |
+
float32x4_t qr_update = vpaddq_f32(qr_update_0, qr_update_1);
|
| 225 |
+
|
| 226 |
+
float32x4_t qr_cell_0 =
|
| 227 |
+
vmulq_f32(vld1q_f32(qr_ptr + 4 * proj_size), sample01);
|
| 228 |
+
float32x4_t qr_cell_1 =
|
| 229 |
+
vmulq_f32(vld1q_f32(qr_ptr + 4 + 4 * proj_size), sample01);
|
| 230 |
+
qr_cell = vpaddq_f32(qr_cell_0, qr_cell_1);
|
| 231 |
+
if (kInputsMode == ARInputsMode::k3ARInputs) {
|
| 232 |
+
float32x4_t w_sample = vdupq_n_f32(*coarse_at_s);
|
| 233 |
+
qr_reset = vmlaq_f32(qr_reset, vld1q_f32(w_hat), w_sample);
|
| 234 |
+
qr_update =
|
| 235 |
+
vmlaq_f32(qr_update, vld1q_f32(w_hat + proj_size), w_sample);
|
| 236 |
+
qr_cell =
|
| 237 |
+
vmlaq_f32(qr_cell, vld1q_f32(w_hat + 2 * proj_size), w_sample);
|
| 238 |
+
}
|
| 239 |
+
reset = vaddq_s32(
|
| 240 |
+
reset, vcvtq_n_s32_f32(qr_reset, GRUMatMulOutType::kMantissaBits));
|
| 241 |
+
update = vaddq_s32(
|
| 242 |
+
update, vcvtq_n_s32_f32(qr_update, GRUMatMulOutType::kMantissaBits));
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
auto reset_conditioning = vld1q_s32(conditioning_ptr);
|
| 246 |
+
auto update_conditioning = vld1q_s32(conditioning_ptr + proj_size);
|
| 247 |
+
float32x4_t cell_conditioning =
|
| 248 |
+
vcvtq_n_f32_s32(vld1q_s32(conditioning_ptr + 2 * proj_size),
|
| 249 |
+
GRUMatMulOutType::kMantissaBits);
|
| 250 |
+
|
| 251 |
+
float32x4_t reset_f32 = fast_sigmoid<GRUMatMulOutType::kExponentBits>(
|
| 252 |
+
vaddq_s32(reset, reset_conditioning));
|
| 253 |
+
float32x4_t update_f32 = fast_sigmoid<GRUMatMulOutType::kExponentBits>(
|
| 254 |
+
vaddq_s32(update, update_conditioning));
|
| 255 |
+
if (kInputsMode == ARInputsMode::k0ARInputs) {
|
| 256 |
+
cell = vmulq_f32(reset_f32, cell);
|
| 257 |
+
} else {
|
| 258 |
+
cell = vmlaq_f32(qr_cell, reset_f32, cell);
|
| 259 |
+
}
|
| 260 |
+
float32x4_t hbar = fast_tanh(vaddq_f32(cell, cell_conditioning));
|
| 261 |
+
|
| 262 |
+
float32x4_t prev_h = vcvtq_n_f32_s32(vmovl_s16(vld1_s16(gru_h_ptr)),
|
| 263 |
+
GRUStateType::kMantissaBits);
|
| 264 |
+
float32x4_t diff = vsubq_f32(prev_h, hbar);
|
| 265 |
+
float32x4_t new_h = vmlaq_f32(hbar, diff, update_f32);
|
| 266 |
+
|
| 267 |
+
// vcvtq_n_s32_f32 = convert fp32 to signed 32-bit fixed point
|
| 268 |
+
// vqrshrn_n_s32 = saturating, rounding, narrowing right shift - used to
|
| 269 |
+
// convert a 32-bit fixed point value to a 16-bit fixed point value
|
| 270 |
+
vst1_s16(gru_h_ptr,
|
| 271 |
+
vqrshrn_n_s32(
|
| 272 |
+
vcvtq_n_s32_f32(new_h, GRUStateType::kMantissaBits + 16), 16));
|
| 273 |
+
// Increment all the pointers.
|
| 274 |
+
conditioning_ptr += kNeonSIMDWidth;
|
| 275 |
+
gru_h_ptr += kNeonSIMDWidth;
|
| 276 |
+
gru_gates_ptr += kNeonSIMDWidth;
|
| 277 |
+
if (SplitGates) gru_gates_other_ptr += kNeonSIMDWidth;
|
| 278 |
+
if (kInputsMode != ARInputsMode::k0ARInputs) {
|
| 279 |
+
qr_ptr += 2 * kNeonSIMDWidth;
|
| 280 |
+
if (kInputsMode == ARInputsMode::k3ARInputs) w_hat += kNeonSIMDWidth;
|
| 281 |
+
}
|
| 282 |
+
}
|
| 283 |
+
}
|
| 284 |
+
#endif // defined __ARM_NEON || defined __aarch64__
|
| 285 |
+
|
| 286 |
+
} // namespace csrblocksparse
|
| 287 |
+
|
| 288 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_
|
sparse_matmul/compute/gru_gates_avx_fixed.h
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2021 Google LLC
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_
|
| 18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_
|
| 19 |
+
|
| 20 |
+
#include <cstdint>
|
| 21 |
+
#if defined __AVX2__
|
| 22 |
+
#include <immintrin.h>
|
| 23 |
+
#endif
|
| 24 |
+
#include <vector>
|
| 25 |
+
|
| 26 |
+
#include "sparse_matmul/compute/ar_inputs.h"
|
| 27 |
+
#include "sparse_matmul/numerics/fast_transcendentals.h"
|
| 28 |
+
|
| 29 |
+
namespace csrblocksparse {
|
| 30 |
+
|
| 31 |
+
#if defined __AVX2__
|
| 32 |
+
|
| 33 |
+
constexpr int kAVX2SIMDWidth = 8;
|
| 34 |
+
|
| 35 |
+
// Loads 8x fixed32 from |ptr0| and adds to |input|.
|
| 36 |
+
// If |kTwoInputs|, also loads from |ptr1| and adds that as well.
|
| 37 |
+
// Returns the 2 or 3-way sum.
|
| 38 |
+
template <bool kTwoInputs>
|
| 39 |
+
inline __m256i LoadAndAddFixed32(const int32_t* ptr0, const int32_t* ptr1,
|
| 40 |
+
const __m256i& input) {
|
| 41 |
+
__m256i data0 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr0));
|
| 42 |
+
if (kTwoInputs) {
|
| 43 |
+
__m256i data1 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr1));
|
| 44 |
+
data0 = _mm256_add_epi32(data0, data1);
|
| 45 |
+
}
|
| 46 |
+
return _mm256_add_epi32(data0, input);
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
// Loads 8x fixed32 from ptr0.
|
| 50 |
+
// If |kTwoInputs|, also loads from |ptr1| and adds.
|
| 51 |
+
// Multiplies the loaded values by the factor and adds to |input|, which also
|
| 52 |
+
// is converted to float.
|
| 53 |
+
// Returns the sum.
|
| 54 |
+
template <bool kTwoInputs>
|
| 55 |
+
inline __m256 LoadMultiplyAddToFloat(const int32_t* ptr0, const int32_t* ptr1,
|
| 56 |
+
const __m256& float_factor,
|
| 57 |
+
const __m256& input) {
|
| 58 |
+
__m256i data0 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr0));
|
| 59 |
+
if (kTwoInputs) {
|
| 60 |
+
__m256i data1 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr1));
|
| 61 |
+
data0 = _mm256_add_epi32(data0, data1);
|
| 62 |
+
}
|
| 63 |
+
__m256 float_result = _mm256_cvtepi32_ps(data0);
|
| 64 |
+
float_result = _mm256_mul_ps(float_result, float_factor);
|
| 65 |
+
return _mm256_add_ps(float_result, input);
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
// Loads 16x float in 2x 8x registers from |ptr0_1| and multiplies by
|
| 69 |
+
// |input_pairs|, likewise formatted as 8x floats, alternating between the two
|
| 70 |
+
// AR inputs and sums each pair of results, making 8x float results.
|
| 71 |
+
// If |kThreeInputs|, also loads 8x float from |ptr2| and multiplies by
|
| 72 |
+
// |third_input|, which must be formatted as 8x float. The second product is
|
| 73 |
+
// added to the previous result.
|
| 74 |
+
// Returns the sum added to |accumulator|.
|
| 75 |
+
template <bool kThreeInputs>
|
| 76 |
+
inline __m256 MultiplyAddFloat(const __m256& input_pairs,
|
| 77 |
+
const __m256& third_input, const float* ptr0_1,
|
| 78 |
+
const float* ptr2, const __m256& accumulator) {
|
| 79 |
+
__m256 data_pair0 = _mm256_load_ps(ptr0_1);
|
| 80 |
+
__m256 data_pair1 = _mm256_load_ps(ptr0_1 + 8);
|
| 81 |
+
data_pair0 = _mm256_mul_ps(data_pair0, input_pairs);
|
| 82 |
+
data_pair1 = _mm256_mul_ps(data_pair1, input_pairs);
|
| 83 |
+
data_pair0 = _mm256_hadd_ps(data_pair0, data_pair1);
|
| 84 |
+
// Swap the middle 2 64 bit pairs to correct the hadd result.
|
| 85 |
+
data_pair0 = _mm256_permute4x64_pd((__m256d)data_pair0, 0xd8);
|
| 86 |
+
if (kThreeInputs) {
|
| 87 |
+
// Load 256 bits (8 x float) of data, then multiply-accumulate.
|
| 88 |
+
data_pair1 = _mm256_load_ps(ptr2);
|
| 89 |
+
data_pair1 = _mm256_mul_ps(data_pair1, third_input);
|
| 90 |
+
data_pair0 = _mm256_add_ps(data_pair0, data_pair1);
|
| 91 |
+
}
|
| 92 |
+
// Add conditioning.
|
| 93 |
+
return _mm256_add_ps(data_pair0, accumulator);
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
// Processes the tanh and the final combination, returns the new GRU state.
|
| 97 |
+
template <int kInputMantissaBits, int kStateMantissaBits, bool kSplitGates>
|
| 98 |
+
inline __m256i GRUComputeState(const __m256& cell0, const __m256& cell1,
|
| 99 |
+
const __m256& reset0, const __m256& reset1,
|
| 100 |
+
const __m256& update0, const __m256& update1,
|
| 101 |
+
const int32_t* gate_ptr,
|
| 102 |
+
const int32_t* gate_other_ptr,
|
| 103 |
+
const void* gru_h_ptr) {
|
| 104 |
+
// Multiply the cell gru output and the reset.
|
| 105 |
+
__m256 float_gru0 = LoadMultiplyAddToFloat<kSplitGates>(
|
| 106 |
+
gate_ptr, gate_other_ptr, reset0, cell0);
|
| 107 |
+
__m256 float_gru1 = LoadMultiplyAddToFloat<kSplitGates>(
|
| 108 |
+
gate_ptr + kAVX2SIMDWidth, gate_other_ptr + kAVX2SIMDWidth, reset1,
|
| 109 |
+
cell1);
|
| 110 |
+
// Compute tanh on the result.
|
| 111 |
+
__m256 hbar0, hbar1;
|
| 112 |
+
float_tanh_float<kInputMantissaBits, TM_ORDER4_FLOAT>(float_gru0, float_gru1,
|
| 113 |
+
hbar0, hbar1);
|
| 114 |
+
// Load the 16-bit previous gru state and update.
|
| 115 |
+
__m256i gru = _mm256_load_si256(reinterpret_cast<__m256i const*>(gru_h_ptr));
|
| 116 |
+
__m256 state_factor =
|
| 117 |
+
_mm256_set1_ps(1.0f / (static_cast<float>(1 << kStateMantissaBits)));
|
| 118 |
+
float_gru0 =
|
| 119 |
+
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(gru)));
|
| 120 |
+
float_gru1 = _mm256_cvtepi32_ps(
|
| 121 |
+
_mm256_cvtepi16_epi32(_mm256_extractf128_si256(gru, 1)));
|
| 122 |
+
float_gru0 = _mm256_mul_ps(float_gru0, state_factor);
|
| 123 |
+
float_gru1 = _mm256_mul_ps(float_gru1, state_factor);
|
| 124 |
+
float_gru0 = _mm256_sub_ps(float_gru0, hbar0);
|
| 125 |
+
float_gru1 = _mm256_sub_ps(float_gru1, hbar1);
|
| 126 |
+
float_gru0 = _mm256_mul_ps(float_gru0, update0);
|
| 127 |
+
float_gru1 = _mm256_mul_ps(float_gru1, update1);
|
| 128 |
+
state_factor = _mm256_set1_ps(static_cast<float>(1 << kStateMantissaBits));
|
| 129 |
+
float_gru0 = _mm256_add_ps(float_gru0, hbar0);
|
| 130 |
+
float_gru1 = _mm256_add_ps(float_gru1, hbar1);
|
| 131 |
+
float_gru0 = _mm256_mul_ps(float_gru0, state_factor);
|
| 132 |
+
float_gru1 = _mm256_mul_ps(float_gru1, state_factor);
|
| 133 |
+
return PackFloatsToFixed16(float_gru0, float_gru1);
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
// According to |kInputsMode|, processes 0, 2 or 3 autoregressive inputs and
|
| 137 |
+
// combines with |input| and |gates*|.
|
| 138 |
+
// With 2 AR inputs, loads 8x pairs of float from |pair_weights| and multiplies
|
| 139 |
+
// by |paired_ar|, likewise formatted as 8x float, but scaled such that the
|
| 140 |
+
// product with pair_weights is on the same scale as |*input| and |*gates0|,
|
| 141 |
+
// and sums each pair result, making 8x float results.
|
| 142 |
+
// If 3 AR inputs, also loads 8x float from |third_weights| and multiplies by
|
| 143 |
+
// |third_ar|, which must be formatted as 8x scaled floats. The second product
|
| 144 |
+
// is added to the previous result.
|
| 145 |
+
// Inputs, 8x fixed32 are loaded from |input|, and added to the total.
|
| 146 |
+
// Finally 8x fixed32 from |gates0| (and |gates1| if |kTwoGates|) are added as
|
| 147 |
+
// well.
|
| 148 |
+
// Returns the total sum as a float, but on the scale of |*input|.
|
| 149 |
+
template <bool kTwoGates, ARInputsMode kInputsMode>
|
| 150 |
+
inline __m256 GruInput32ToFloat(const __m256& paired_ar,
|
| 151 |
+
const __m256& third_ar,
|
| 152 |
+
const float* pair_weights,
|
| 153 |
+
const float* third_weights,
|
| 154 |
+
const int32_t* gates0, const int32_t* gates1,
|
| 155 |
+
const int32_t* input) {
|
| 156 |
+
__m256i data32 = _mm256_load_si256(reinterpret_cast<__m256i const*>(input));
|
| 157 |
+
data32 = LoadAndAddFixed32<kTwoGates>(gates0, gates1, data32);
|
| 158 |
+
__m256 float_data = _mm256_cvtepi32_ps(data32);
|
| 159 |
+
if (kInputsMode != ARInputsMode::k0ARInputs) {
|
| 160 |
+
float_data = MultiplyAddFloat<kInputsMode == ARInputsMode::k3ARInputs>(
|
| 161 |
+
paired_ar, third_ar, pair_weights, third_weights, float_data);
|
| 162 |
+
}
|
| 163 |
+
return float_data;
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
// Generic GRU gates function controlled by template parameters thus:
|
| 167 |
+
// - |kInputBits|: the mantissa bits in |*input_ptr|, |*gru_recurrent_ptr|.
|
| 168 |
+
// - |kStateBits|: the mantissa_bits in |*gru_state_ptr|.
|
| 169 |
+
// - |kInputsMode == |k0ARInputs|: There are no autoregressive inputs so
|
| 170 |
+
// |ar_sample, |ar_sample1|, |ar_sample2|, |ar_01_weights|, |ar_2_weights| are
|
| 171 |
+
// ignored.
|
| 172 |
+
// - |kInputsMode| == |k2ARInputs|: |ar_sample0|, |ar_sample1| are multiplied by
|
| 173 |
+
// |ar_01_weights| and added to the (conditioning) input.
|
| 174 |
+
// - |kInputsMode| == |k3ARInputs|: |ar_sample2| is multiplied by |ar_2_weights|
|
| 175 |
+
// and added to the other two AR inputs (and added to the conditioning input).
|
| 176 |
+
// - |kReplicas| determines the number of duplicates of the output to be
|
| 177 |
+
// written, separated by |replica_stride|. If zero, then the number of
|
| 178 |
+
// replicas is variable and taken from the |replicas| argument.
|
| 179 |
+
// - If |kSplitGates| is true: The |*gru_recurrent_other_ptr| is secondary
|
| 180 |
+
// recurrent input that must be added to |*gru_recurrent_ptr|.
|
| 181 |
+
// - |start|, |end| are |rows| in [0, |state_size|] to be processed by this
|
| 182 |
+
// thread.
|
| 183 |
+
//
|
| 184 |
+
// Previous state is read from |*gru_state_ptr| and the new state is written to
|
| 185 |
+
// *(|gru_state_ptr| + i * |replica_stride| for i in [0, |kReplicas|]).
|
| 186 |
+
template <int kInputBits, int kStateBits,
|
| 187 |
+
ARInputsMode kInputsMode = ARInputsMode::k0ARInputs,
|
| 188 |
+
int kReplicas = 1, bool kSplitGates = false>
|
| 189 |
+
inline void GruGatesTemplate(
|
| 190 |
+
int start, int end, int state_size, int replicas, int replica_stride,
|
| 191 |
+
const int32_t* gru_recurrent_ptr, const int32_t* input_ptr,
|
| 192 |
+
const std::pair<float, float>* ar_sample01, const float* ar_01_weights,
|
| 193 |
+
const float* ar_sample2, const float* ar_2_weights,
|
| 194 |
+
const int32_t* gru_recurrent_other_ptr, int16_t* gru_state_ptr) {
|
| 195 |
+
constexpr int kQRIncrement = kAVX2SIMDWidth;
|
| 196 |
+
// Increment all the pointers to save on pointer arithmetic in the loop.
|
| 197 |
+
input_ptr += start;
|
| 198 |
+
gru_state_ptr += start;
|
| 199 |
+
gru_recurrent_ptr += start;
|
| 200 |
+
if (kSplitGates) gru_recurrent_other_ptr += start;
|
| 201 |
+
__m256 ar_2_inputs, ar_3rd_input;
|
| 202 |
+
if (kInputsMode != ARInputsMode::k0ARInputs) {
|
| 203 |
+
ar_01_weights += 2 * start;
|
| 204 |
+
ar_2_inputs = _mm256_castsi256_ps(
|
| 205 |
+
_mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(ar_sample01)));
|
| 206 |
+
if (kInputsMode == ARInputsMode::k3ARInputs) {
|
| 207 |
+
ar_2_weights += start;
|
| 208 |
+
ar_3rd_input = _mm256_set1_ps(*ar_sample2);
|
| 209 |
+
} else {
|
| 210 |
+
ar_3rd_input = {};
|
| 211 |
+
}
|
| 212 |
+
} else {
|
| 213 |
+
ar_2_inputs = {};
|
| 214 |
+
ar_3rd_input = {};
|
| 215 |
+
}
|
| 216 |
+
// The transcendentals handle 2x registers of data at once, so we have to do
|
| 217 |
+
// everything in duplicate.
|
| 218 |
+
for (int i = start; i < end; i += kQRIncrement * 2) {
|
| 219 |
+
// Load 8 pairs of fixed16s for each of reset, update and cell.
|
| 220 |
+
__m256 reset0 = GruInput32ToFloat<kSplitGates, kInputsMode>(
|
| 221 |
+
ar_2_inputs, ar_3rd_input, ar_01_weights, ar_2_weights,
|
| 222 |
+
gru_recurrent_ptr, gru_recurrent_other_ptr, input_ptr);
|
| 223 |
+
__m256 reset1 = GruInput32ToFloat<kSplitGates, kInputsMode>(
|
| 224 |
+
ar_2_inputs, ar_3rd_input, ar_01_weights + 2 * kQRIncrement,
|
| 225 |
+
ar_2_weights + kQRIncrement, gru_recurrent_ptr + kAVX2SIMDWidth,
|
| 226 |
+
gru_recurrent_other_ptr + kAVX2SIMDWidth, input_ptr + kAVX2SIMDWidth);
|
| 227 |
+
float_sigmoid_float<kInputBits>(reset0, reset1);
|
| 228 |
+
__m256 update0 = GruInput32ToFloat<kSplitGates, kInputsMode>(
|
| 229 |
+
ar_2_inputs, ar_3rd_input, ar_01_weights + 2 * state_size,
|
| 230 |
+
ar_2_weights + state_size, gru_recurrent_ptr + state_size,
|
| 231 |
+
gru_recurrent_other_ptr + state_size, input_ptr + state_size);
|
| 232 |
+
__m256 update1 = GruInput32ToFloat<kSplitGates, kInputsMode>(
|
| 233 |
+
ar_2_inputs, ar_3rd_input,
|
| 234 |
+
ar_01_weights + 2 * state_size + 2 * kQRIncrement,
|
| 235 |
+
ar_2_weights + state_size + kQRIncrement,
|
| 236 |
+
gru_recurrent_ptr + state_size + kAVX2SIMDWidth,
|
| 237 |
+
gru_recurrent_other_ptr + state_size + kAVX2SIMDWidth,
|
| 238 |
+
input_ptr + state_size + kAVX2SIMDWidth);
|
| 239 |
+
float_sigmoid_float<kInputBits>(update0, update1);
|
| 240 |
+
__m256 cell0 = _mm256_cvtepi32_ps(_mm256_load_si256(
|
| 241 |
+
reinterpret_cast<__m256i const*>(input_ptr + 2 * state_size)));
|
| 242 |
+
__m256 cell1 =
|
| 243 |
+
_mm256_cvtepi32_ps(_mm256_load_si256(reinterpret_cast<__m256i const*>(
|
| 244 |
+
input_ptr + 2 * state_size + kAVX2SIMDWidth)));
|
| 245 |
+
if (kInputsMode != ARInputsMode::k0ARInputs) {
|
| 246 |
+
cell0 = MultiplyAddFloat<kInputsMode == ARInputsMode::k3ARInputs>(
|
| 247 |
+
ar_2_inputs, ar_3rd_input, ar_01_weights + 4 * state_size,
|
| 248 |
+
ar_2_weights + 2 * state_size, cell0);
|
| 249 |
+
cell1 = MultiplyAddFloat<kInputsMode == ARInputsMode::k3ARInputs>(
|
| 250 |
+
ar_2_inputs, ar_3rd_input,
|
| 251 |
+
ar_01_weights + 4 * state_size + 2 * kQRIncrement,
|
| 252 |
+
ar_2_weights + 2 * state_size + kQRIncrement, cell1);
|
| 253 |
+
}
|
| 254 |
+
__m256i gru_state = GRUComputeState<kInputBits, kStateBits, kSplitGates>(
|
| 255 |
+
cell0, cell1, reset0, reset1, update0, update1,
|
| 256 |
+
gru_recurrent_ptr + 2 * state_size,
|
| 257 |
+
gru_recurrent_other_ptr + 2 * state_size, gru_state_ptr);
|
| 258 |
+
if (kReplicas > 0) {
|
| 259 |
+
// With |kReplicas| a template parameter, the compiler will unroll the
|
| 260 |
+
// loop.
|
| 261 |
+
for (int j = 0; j < kReplicas; ++j) {
|
| 262 |
+
_mm256_store_si256(
|
| 263 |
+
reinterpret_cast<__m256i*>(gru_state_ptr + j * replica_stride),
|
| 264 |
+
gru_state);
|
| 265 |
+
}
|
| 266 |
+
} else {
|
| 267 |
+
// This loop will not unroll as replicas is variable.
|
| 268 |
+
for (int j = 0; j < replicas; ++j) {
|
| 269 |
+
_mm256_store_si256(
|
| 270 |
+
reinterpret_cast<__m256i*>(gru_state_ptr + j * replica_stride),
|
| 271 |
+
gru_state);
|
| 272 |
+
}
|
| 273 |
+
}
|
| 274 |
+
// Increment all the pointers.
|
| 275 |
+
input_ptr += 2 * kAVX2SIMDWidth;
|
| 276 |
+
gru_state_ptr += 2 * kAVX2SIMDWidth;
|
| 277 |
+
gru_recurrent_ptr += 2 * kAVX2SIMDWidth;
|
| 278 |
+
if (kSplitGates) gru_recurrent_other_ptr += 2 * kAVX2SIMDWidth;
|
| 279 |
+
if (kInputsMode != ARInputsMode::k0ARInputs) {
|
| 280 |
+
ar_01_weights += 4 * kQRIncrement;
|
| 281 |
+
if (kInputsMode == ARInputsMode::k3ARInputs)
|
| 282 |
+
ar_2_weights += 2 * kQRIncrement;
|
| 283 |
+
}
|
| 284 |
+
}
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
// Dispatches calls to the GruGatesTemplate function above converting the
|
| 288 |
+
// replicas variable argument to a template parameter to allow the compiler to
|
| 289 |
+
// unroll the write loop.
|
| 290 |
+
// |ar_sample01| packs sample 0 and 1 into a pair because the QR weights are
|
| 291 |
+
// formatted with the weights interleaved for sample 0 and 1. The two samples
|
| 292 |
+
// represent coarse and fine for WaveRNN.
|
| 293 |
+
template <int kInputBits, int kStateBits,
|
| 294 |
+
ARInputsMode kInputsMode = ARInputsMode::k2ARInputs,
|
| 295 |
+
bool kSplitGates = false>
|
| 296 |
+
inline void GruGatesAVXFixed(
|
| 297 |
+
int start, int end, int state_size, const int32_t* gru_recurrent_ptr,
|
| 298 |
+
const int32_t* input_ptr, const std::pair<float, float>* ar_sample01,
|
| 299 |
+
const float* ar_01_weights, int num_replicas, int replica_stride,
|
| 300 |
+
const float* ar_sample2, const float* ar_2_weights,
|
| 301 |
+
const int32_t* gru_recurrent_other_ptr, int16_t* gru_state_ptr) {
|
| 302 |
+
// Convert the number of replicas from a variable to a template parameter
|
| 303 |
+
// with a switch. This enables the compiler to unroll the loop for
|
| 304 |
+
// the write, making it faster for common numbers of threads.
|
| 305 |
+
switch (num_replicas) {
|
| 306 |
+
case 1:
|
| 307 |
+
GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/1,
|
| 308 |
+
kSplitGates>(
|
| 309 |
+
start, end, state_size, num_replicas, replica_stride,
|
| 310 |
+
gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
|
| 311 |
+
ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
|
| 312 |
+
break;
|
| 313 |
+
case 2:
|
| 314 |
+
GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/2,
|
| 315 |
+
kSplitGates>(
|
| 316 |
+
start, end, state_size, num_replicas, replica_stride,
|
| 317 |
+
gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
|
| 318 |
+
ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
|
| 319 |
+
break;
|
| 320 |
+
case 4:
|
| 321 |
+
GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/4,
|
| 322 |
+
kSplitGates>(
|
| 323 |
+
start, end, state_size, num_replicas, replica_stride,
|
| 324 |
+
gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
|
| 325 |
+
ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
|
| 326 |
+
break;
|
| 327 |
+
case 6:
|
| 328 |
+
GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/6,
|
| 329 |
+
kSplitGates>(
|
| 330 |
+
start, end, state_size, num_replicas, replica_stride,
|
| 331 |
+
gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
|
| 332 |
+
ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
|
| 333 |
+
break;
|
| 334 |
+
default:
|
| 335 |
+
// Zero |kReplicas| tells the function to use the |num_replicas| variable.
|
| 336 |
+
GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/0,
|
| 337 |
+
kSplitGates>(
|
| 338 |
+
start, end, state_size, num_replicas, replica_stride,
|
| 339 |
+
gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
|
| 340 |
+
ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
|
| 341 |
+
}
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
#endif // __AVX2__
|
| 345 |
+
|
| 346 |
+
} // namespace csrblocksparse
|
| 347 |
+
|
| 348 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_
|
sparse_matmul/compute/gru_gates_generic.h
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2021 Google LLC
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_
|
| 18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_
|
| 19 |
+
|
| 20 |
+
#include "sparse_matmul/compute/ar_inputs.h"
|
| 21 |
+
#include "sparse_matmul/numerics/fast_transcendentals.h"
|
| 22 |
+
|
| 23 |
+
namespace csrblocksparse {
|
| 24 |
+
|
| 25 |
+
constexpr int kGenericSIMDWidth = 4;
|
| 26 |
+
|
| 27 |
+
// TODO(b/188702959): Rename arguments to match gru_gates.h.
|
| 28 |
+
template <typename GRUStateType, typename GRUMatMulOutType, typename QR_W_Type,
|
| 29 |
+
typename SampleType, ARInputsMode kInputsMode,
|
| 30 |
+
bool SplitGates = false>
|
| 31 |
+
void GoThroughGates(int start, int end, const QR_W_Type* qr_ptr,
|
| 32 |
+
const GRUMatMulOutType* gru_gates_ptr,
|
| 33 |
+
const GRUMatMulOutType* gru_gates_other_ptr,
|
| 34 |
+
const GRUMatMulOutType* conditioning_ptr,
|
| 35 |
+
GRUStateType* gru_h_ptr, const QR_W_Type* w_hat,
|
| 36 |
+
int proj_size, const SampleType* coarse_at_sminus1,
|
| 37 |
+
const SampleType* fine_at_sminus1,
|
| 38 |
+
const SampleType* coarse_at_s = nullptr) {
|
| 39 |
+
float qr_cell = 0.0f, reset, update, cell;
|
| 40 |
+
for (int i = start; i < end; ++i) {
|
| 41 |
+
if (kInputsMode == ARInputsMode::k0ARInputs) {
|
| 42 |
+
reset = static_cast<float>(gru_gates_ptr[i]);
|
| 43 |
+
update = static_cast<float>(gru_gates_ptr[proj_size + i]);
|
| 44 |
+
} else {
|
| 45 |
+
float qr_c_reset = static_cast<float>(qr_ptr[2 * i + 0]);
|
| 46 |
+
float qr_f_reset = static_cast<float>(qr_ptr[2 * i + 1]);
|
| 47 |
+
float qr_c_update = static_cast<float>(qr_ptr[2 * proj_size + 2 * i + 0]);
|
| 48 |
+
float qr_f_update = static_cast<float>(qr_ptr[2 * proj_size + 2 * i + 1]);
|
| 49 |
+
float qr_c_cell = static_cast<float>(qr_ptr[4 * proj_size + 2 * i + 0]);
|
| 50 |
+
float qr_f_cell = static_cast<float>(qr_ptr[4 * proj_size + 2 * i + 1]);
|
| 51 |
+
float w_hat_i_reset = 0.0f;
|
| 52 |
+
float w_hat_i_update = 0.0f;
|
| 53 |
+
float w_hat_i_cell = 0.0f;
|
| 54 |
+
if (kInputsMode == ARInputsMode::k3ARInputs) {
|
| 55 |
+
w_hat_i_reset = static_cast<float>(w_hat[i]);
|
| 56 |
+
w_hat_i_update = static_cast<float>(w_hat[proj_size + i]);
|
| 57 |
+
w_hat_i_cell = static_cast<float>(w_hat[2 * proj_size + i]);
|
| 58 |
+
}
|
| 59 |
+
float coarse = static_cast<float>(coarse_at_sminus1[0]);
|
| 60 |
+
float fine = static_cast<float>(fine_at_sminus1[0]);
|
| 61 |
+
reset = qr_c_reset * coarse + qr_f_reset * fine;
|
| 62 |
+
update = qr_c_update * coarse + qr_f_update * fine;
|
| 63 |
+
qr_cell = qr_c_cell * coarse + qr_f_cell * fine;
|
| 64 |
+
if (kInputsMode == ARInputsMode::k3ARInputs) {
|
| 65 |
+
float coarse = static_cast<float>(coarse_at_s[0]);
|
| 66 |
+
reset += w_hat_i_reset * coarse;
|
| 67 |
+
update += w_hat_i_update * coarse;
|
| 68 |
+
qr_cell += w_hat_i_cell * coarse;
|
| 69 |
+
}
|
| 70 |
+
reset += static_cast<float>(gru_gates_ptr[i]);
|
| 71 |
+
update += static_cast<float>(gru_gates_ptr[proj_size + i]);
|
| 72 |
+
}
|
| 73 |
+
cell = static_cast<float>(gru_gates_ptr[2 * proj_size + i]);
|
| 74 |
+
if (SplitGates) {
|
| 75 |
+
reset += static_cast<float>(gru_gates_other_ptr[i]);
|
| 76 |
+
update += static_cast<float>(gru_gates_other_ptr[proj_size + i]);
|
| 77 |
+
cell += static_cast<float>(gru_gates_other_ptr[2 * proj_size + i]);
|
| 78 |
+
}
|
| 79 |
+
float reset_conditioning = static_cast<float>(conditioning_ptr[i]);
|
| 80 |
+
float update_conditioning =
|
| 81 |
+
static_cast<float>(conditioning_ptr[proj_size + i]);
|
| 82 |
+
float cell_conditioning =
|
| 83 |
+
static_cast<float>(conditioning_ptr[2 * proj_size + i]);
|
| 84 |
+
reset = fast_sigmoid(reset + reset_conditioning);
|
| 85 |
+
update = fast_sigmoid(update + update_conditioning);
|
| 86 |
+
float hbar = fast_tanh(qr_cell + reset * cell + cell_conditioning);
|
| 87 |
+
int h_index = i;
|
| 88 |
+
float prev_h = static_cast<float>(gru_h_ptr[h_index]);
|
| 89 |
+
float diff = prev_h - hbar;
|
| 90 |
+
float new_h = hbar + diff * update;
|
| 91 |
+
gru_h_ptr[h_index] = static_cast<GRUStateType>(new_h);
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
} // namespace csrblocksparse
|
| 96 |
+
|
| 97 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_
|
sparse_matmul/compute/gru_gates_test.cc
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright 2021 Google LLC
|
| 2 |
+
//
|
| 3 |
+
// Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
// you may not use this file except in compliance with the License.
|
| 5 |
+
// You may obtain a copy of the License at
|
| 6 |
+
//
|
| 7 |
+
// http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
//
|
| 9 |
+
// Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
// distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
// See the License for the specific language governing permissions and
|
| 13 |
+
// limitations under the License.
|
| 14 |
+
|
| 15 |
+
#include "sparse_matmul/compute/gru_gates.h"
|
| 16 |
+
|
| 17 |
+
#include <cstdint>
|
| 18 |
+
#include <cstring>
|
| 19 |
+
#include <numeric>
|
| 20 |
+
|
| 21 |
+
#include "absl/memory/memory.h"
|
| 22 |
+
#include "absl/types/span.h"
|
| 23 |
+
#include "gmock/gmock.h"
|
| 24 |
+
#include "gtest/gtest.h"
|
| 25 |
+
|
| 26 |
+
namespace {
|
| 27 |
+
|
| 28 |
+
using csrblocksparse::ARInputsMode;
|
| 29 |
+
|
| 30 |
+
template <typename GRUStateType, typename InputType, typename SampleType = void,
|
| 31 |
+
csrblocksparse::ARInputsMode kInputsMode, bool kSplitGates>
|
| 32 |
+
csrblocksparse::CacheAlignedVector<GRUStateType> TestGruGates() {
|
| 33 |
+
using SampleWeightType = float;
|
| 34 |
+
constexpr int kStateSize = 16;
|
| 35 |
+
csrblocksparse::CacheAlignedVector<SampleWeightType> qr(6 * kStateSize);
|
| 36 |
+
csrblocksparse::CacheAlignedVector<SampleWeightType> w(3 * kStateSize);
|
| 37 |
+
csrblocksparse::CacheAlignedVector<InputType> gru_gates(3 * kStateSize);
|
| 38 |
+
csrblocksparse::CacheAlignedVector<InputType> gru_other_gates(3 * kStateSize);
|
| 39 |
+
csrblocksparse::CacheAlignedVector<InputType> conditioning(3 * kStateSize);
|
| 40 |
+
csrblocksparse::CacheAlignedVector<GRUStateType> gru_h(kStateSize);
|
| 41 |
+
csrblocksparse::GruGates<GRUStateType, InputType, SampleType> gru_gates_impl;
|
| 42 |
+
const SampleType kCoarseAtSMinus1(0.03f);
|
| 43 |
+
const SampleType kFineAtSMinus1(0.07f);
|
| 44 |
+
const SampleType kCoarseAtS(-0.02f);
|
| 45 |
+
|
| 46 |
+
qr.FillOnes();
|
| 47 |
+
w.FillOnes();
|
| 48 |
+
gru_gates.FillRandom();
|
| 49 |
+
gru_other_gates.FillRandom();
|
| 50 |
+
conditioning.FillRandom();
|
| 51 |
+
gru_h.FillZero();
|
| 52 |
+
|
| 53 |
+
gru_gates_impl.template GruWithARInput<kInputsMode, kSplitGates>(
|
| 54 |
+
/*start=*/0, /*end=*/kStateSize, kStateSize, gru_gates.data(),
|
| 55 |
+
conditioning.data(), gru_h.data(), &kCoarseAtSMinus1, &kFineAtSMinus1,
|
| 56 |
+
qr.data(),
|
| 57 |
+
/*num_replicas=*/1, /*replica_stride=*/0, &kCoarseAtS, w.data(),
|
| 58 |
+
gru_other_gates.data());
|
| 59 |
+
return gru_h;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
TEST(GruGates, FloatWaveRNNCoarseMatchesGolden) {
|
| 63 |
+
// If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers
|
| 64 |
+
// will also need to change.
|
| 65 |
+
const std::vector<float> kGoldenValues = {
|
| 66 |
+
0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.746f, 0.0f, 0.0f,
|
| 67 |
+
0.0f, 0.0f, 0.970f, 0.0f, 0.0f, 1.0f, 0.0f, -0.993f};
|
| 68 |
+
csrblocksparse::CacheAlignedVector<float> gru_h =
|
| 69 |
+
TestGruGates<float, float, float, ARInputsMode::k2ARInputs,
|
| 70 |
+
/*kSplitGates=*/true>();
|
| 71 |
+
|
| 72 |
+
ASSERT_EQ(kGoldenValues.size(), gru_h.size());
|
| 73 |
+
for (int i = 0; i < gru_h.size(); ++i) {
|
| 74 |
+
EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i;
|
| 75 |
+
}
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
TEST(GruGates, FloatWaveRNNFineMatchesGolden) {
|
| 79 |
+
// If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers
|
| 80 |
+
// will also need to change.
|
| 81 |
+
const std::vector<float> kGoldenValues = {
|
| 82 |
+
0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.737f, 0.0f, 0.0f,
|
| 83 |
+
0.0f, 0.0f, 0.969f, 0.0f, 0.0f, 1.0f, 0.0f, -0.994f};
|
| 84 |
+
csrblocksparse::CacheAlignedVector<float> gru_h =
|
| 85 |
+
TestGruGates<float, float, float, ARInputsMode::k3ARInputs,
|
| 86 |
+
/*kSplitGates=*/true>();
|
| 87 |
+
|
| 88 |
+
ASSERT_EQ(kGoldenValues.size(), gru_h.size());
|
| 89 |
+
for (int i = 0; i < gru_h.size(); ++i) {
|
| 90 |
+
EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i;
|
| 91 |
+
}
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
TEST(GruGates, FloatTwoArInputsNonSplitGateMatchesGolden) {
|
| 95 |
+
// If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers
|
| 96 |
+
// will also need to change.
|
| 97 |
+
const std::vector<float> kGoldenValues = {
|
| 98 |
+
0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.714f, 0.0f, -0.002f,
|
| 99 |
+
0.0f, 0.0f, 0.970f, 0.0f, 0.0f, 1.0f, 0.0f, -0.965f};
|
| 100 |
+
csrblocksparse::CacheAlignedVector<float> gru_h =
|
| 101 |
+
TestGruGates<float, float, float, ARInputsMode::k2ARInputs,
|
| 102 |
+
/*kSplitGates=*/false>();
|
| 103 |
+
|
| 104 |
+
ASSERT_EQ(kGoldenValues.size(), gru_h.size());
|
| 105 |
+
for (int i = 0; i < gru_h.size(); ++i) {
|
| 106 |
+
EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i;
|
| 107 |
+
}
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
TEST(GruGates, FixedWaveRNNCoarseMatchesFloat) {
|
| 111 |
+
using GRUMatMulOutType = csrblocksparse::fixed32<11>;
|
| 112 |
+
using GRUStateType = csrblocksparse::fixed16<2>;
|
| 113 |
+
using SampleType = csrblocksparse::fixed16<0>;
|
| 114 |
+
csrblocksparse::CacheAlignedVector<float> float_gru_h =
|
| 115 |
+
TestGruGates<float, float, float, ARInputsMode::k2ARInputs,
|
| 116 |
+
/*kSplitGates=*/true>();
|
| 117 |
+
csrblocksparse::CacheAlignedVector<GRUStateType> fixed_gru_h =
|
| 118 |
+
TestGruGates<GRUStateType, GRUMatMulOutType, SampleType,
|
| 119 |
+
ARInputsMode::k2ARInputs, /*kSplitGates=*/true>();
|
| 120 |
+
|
| 121 |
+
ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size());
|
| 122 |
+
for (int i = 0; i < fixed_gru_h.size(); ++i) {
|
| 123 |
+
EXPECT_NEAR(float_gru_h[i], static_cast<float>(fixed_gru_h[i]), 1e-3)
|
| 124 |
+
<< "i=" << i;
|
| 125 |
+
}
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
TEST(GruGates, FixedWaveRNNFineMatchesFloat) {
|
| 129 |
+
using GRUMatMulOutType = csrblocksparse::fixed32<11>;
|
| 130 |
+
using GRUStateType = csrblocksparse::fixed16<2>;
|
| 131 |
+
using SampleType = csrblocksparse::fixed16<0>;
|
| 132 |
+
csrblocksparse::CacheAlignedVector<float> float_gru_h =
|
| 133 |
+
TestGruGates<float, float, float, ARInputsMode::k3ARInputs,
|
| 134 |
+
/*kSplitGates=*/true>();
|
| 135 |
+
csrblocksparse::CacheAlignedVector<GRUStateType> fixed_gru_h =
|
| 136 |
+
TestGruGates<GRUStateType, GRUMatMulOutType, SampleType,
|
| 137 |
+
ARInputsMode::k3ARInputs, /*kSplitGates=*/true>();
|
| 138 |
+
|
| 139 |
+
ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size());
|
| 140 |
+
for (int i = 0; i < fixed_gru_h.size(); ++i) {
|
| 141 |
+
EXPECT_NEAR(float_gru_h[i], static_cast<float>(fixed_gru_h[i]), 1e-3)
|
| 142 |
+
<< "i=" << i;
|
| 143 |
+
}
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
TEST(GruGates, FixedTwoArInputsNonSplitGateMatchesFloat) {
|
| 147 |
+
using GRUMatMulOutType = csrblocksparse::fixed32<11>;
|
| 148 |
+
using GRUStateType = csrblocksparse::fixed16<2>;
|
| 149 |
+
using SampleType = csrblocksparse::fixed16<0>;
|
| 150 |
+
csrblocksparse::CacheAlignedVector<float> float_gru_h =
|
| 151 |
+
TestGruGates<float, float, float, ARInputsMode::k2ARInputs,
|
| 152 |
+
/*kSplitGates=*/false>();
|
| 153 |
+
csrblocksparse::CacheAlignedVector<GRUStateType> fixed_gru_h =
|
| 154 |
+
TestGruGates<GRUStateType, GRUMatMulOutType, SampleType,
|
| 155 |
+
ARInputsMode::k2ARInputs, /*kSplitGates=*/false>();
|
| 156 |
+
|
| 157 |
+
ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size());
|
| 158 |
+
for (int i = 0; i < fixed_gru_h.size(); ++i) {
|
| 159 |
+
EXPECT_NEAR(float_gru_h[i], static_cast<float>(fixed_gru_h[i]), 1e-3)
|
| 160 |
+
<< "i=" << i;
|
| 161 |
+
}
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
} // namespace
|
sparse_matmul/compute/kernels_arm.h
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sparse_matmul/compute/kernels_avx.h
ADDED
|
@@ -0,0 +1,601 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2021 Google LLC
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_
|
| 18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_
|
| 19 |
+
|
| 20 |
+
#if defined __AVX__
|
| 21 |
+
#include <immintrin.h>
|
| 22 |
+
|
| 23 |
+
#include <algorithm>
|
| 24 |
+
#include <type_traits>
|
| 25 |
+
// TODO(b/188702959): Remove fast_transcendentals with GRU refactor.
|
| 26 |
+
#include "sparse_matmul/numerics/fast_transcendentals.h"
|
| 27 |
+
#include "sparse_matmul/numerics/fixed_types.h"
|
| 28 |
+
#include "sparse_matmul/numerics/float16_types.h"
|
| 29 |
+
#include "sparse_matmul/numerics/type_utils.h"
|
| 30 |
+
|
| 31 |
+
namespace csrblocksparse {
|
| 32 |
+
namespace detail {
|
| 33 |
+
|
| 34 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 35 |
+
struct IsAllowableFloatTypes
|
| 36 |
+
: std::integral_constant<bool, std::is_same<WeightType, float>::value &&
|
| 37 |
+
std::is_same<RhsType, float>::value &&
|
| 38 |
+
std::is_same<OutType, float>::value> {};
|
| 39 |
+
|
| 40 |
+
#if defined __AVX2__
|
| 41 |
+
// 16-bit inputs, 32-bit output exponent matches sum of input exponents
|
| 42 |
+
// OR
|
| 43 |
+
// 16-bit inputs, 16-bit output - will shift to match exponent
|
| 44 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 45 |
+
struct IsAllowableFixedTypes
|
| 46 |
+
: std::integral_constant<bool, (IsFixed16Type<WeightType>::value &&
|
| 47 |
+
IsFixed16Type<RhsType>::value) &&
|
| 48 |
+
(IsFixed32Type<OutType>::value ||
|
| 49 |
+
IsFixed16Type<OutType>::value)> {};
|
| 50 |
+
|
| 51 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 52 |
+
struct ShouldEnableGenericKernel
|
| 53 |
+
: std::integral_constant<
|
| 54 |
+
bool,
|
| 55 |
+
!IsAllowableFloatTypes<WeightType, RhsType, OutType>::value &&
|
| 56 |
+
!IsAllowableFixedTypes<WeightType, RhsType, OutType>::value> {};
|
| 57 |
+
|
| 58 |
+
template <typename Type>
|
| 59 |
+
struct IsAddableFixedTypes
|
| 60 |
+
: std::integral_constant<bool, IsFixed32Type<Type>::value ||
|
| 61 |
+
IsFixed16Type<Type>::value> {};
|
| 62 |
+
template <typename Type>
|
| 63 |
+
struct ShouldEnableGenericAdd
|
| 64 |
+
: std::integral_constant<bool, !IsAddableFixedTypes<Type>::value> {};
|
| 65 |
+
|
| 66 |
+
#else // No AVX2.
|
| 67 |
+
|
| 68 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 69 |
+
struct ShouldEnableGenericKernel
|
| 70 |
+
: std::integral_constant<
|
| 71 |
+
bool, !IsAllowableFloatTypes<WeightType, RhsType, OutType>::value> {};
|
| 72 |
+
|
| 73 |
+
template <typename Type>
|
| 74 |
+
struct ShouldEnableGenericAdd : std::true_type {};
|
| 75 |
+
#endif // __AVX2__
|
| 76 |
+
|
| 77 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 78 |
+
struct ShouldEnableGenericSpMV_4x4
|
| 79 |
+
: ShouldEnableGenericKernel<WeightType, RhsType, OutType> {};
|
| 80 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 81 |
+
struct ShouldEnableGenericSpMM5_4x4
|
| 82 |
+
: ShouldEnableGenericKernel<WeightType, RhsType, OutType> {};
|
| 83 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 84 |
+
struct ShouldEnableGenericSpMV_1x1 : std::true_type {};
|
| 85 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 86 |
+
struct ShouldEnableGenericSpMM5_1x1 : std::true_type {};
|
| 87 |
+
|
| 88 |
+
// The computational routines do NO error checking for speed. It is assumed
|
| 89 |
+
// that this has been handled by CSRBlockSparseMatrix.
|
| 90 |
+
|
| 91 |
+
// In-line function to extract results from a pair of registers and store in
|
| 92 |
+
// memory. Note that the non-const references are registers, and are modified
|
| 93 |
+
// by this function!
|
| 94 |
+
inline void Extract4Results(bool relu, __m256& sum1, __m256& sum2,
|
| 95 |
+
float** out_ptr) {
|
| 96 |
+
// Horizontally add the results. We have 2 registers, |sum1| and |sum2| that
|
| 97 |
+
// each contain 2 sets of 4 values that need to be added.
|
| 98 |
+
sum1 = _mm256_hadd_ps(sum1, sum2);
|
| 99 |
+
sum1 = _mm256_hadd_ps(sum1, sum1);
|
| 100 |
+
// Now |sum1| contains [|res0|, |res2|, |res0|, |res2|, |res1|, |res3|,
|
| 101 |
+
// |res1|, |res3|]
|
| 102 |
+
if (relu) {
|
| 103 |
+
sum1 = _mm256_max_ps(sum1, _mm256_setzero_ps());
|
| 104 |
+
}
|
| 105 |
+
// It is really hard in AVX to cross the 128 bit 'lanes' and this is the
|
| 106 |
+
// *only* way to do it.
|
| 107 |
+
// Get the top half of |sum1| in to bottom of |sum2|.
|
| 108 |
+
sum2 = _mm256_permute2f128_ps(sum1, sum1, 1);
|
| 109 |
+
// Interleave the values between the two registers.
|
| 110 |
+
sum1 = _mm256_unpacklo_ps(sum1, sum2);
|
| 111 |
+
// Save the lower 128 bits (4 floats).
|
| 112 |
+
__m128 result = _mm256_extractf128_ps(sum1, 0);
|
| 113 |
+
_mm_store_ps(*out_ptr, result);
|
| 114 |
+
*out_ptr += 4;
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
|
| 118 |
+
// blocked pattern, x is a vector and b is vector. Weights are stored for this
|
| 119 |
+
// routine by making each 4x4 block contiguous. Blocks are ordered in standard
|
| 120 |
+
// row-major format. column indices are converted to deltas and then multiplied
|
| 121 |
+
// by 2 to convert to bytes, so that the value can be used directly to offset
|
| 122 |
+
// the pointer into the rhs vector.
|
| 123 |
+
//
|
| 124 |
+
// NOTE: The bias is expected to have be multiplied by .25f prior to calling
|
| 125 |
+
// this function. This is automatically taken care of in SparseLinearLayer.
|
| 126 |
+
// The bias is reconstructed through horizontal additions, leads to a small
|
| 127 |
+
// speedup by reducing latencies at the end of the loop.
|
| 128 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 129 |
+
typename std::enable_if<std::is_same<WeightType, float>::value &&
|
| 130 |
+
std::is_same<RhsType, float>::value &&
|
| 131 |
+
std::is_same<OutType, float>::value>::type
|
| 132 |
+
SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
|
| 133 |
+
const int32_t* nnz_per_row, const RhsType* rhs_ptr,
|
| 134 |
+
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
|
| 135 |
+
OutType* out_ptr, int64_t assigned_rows,
|
| 136 |
+
int64_t rows /* only used in SpMM variants */,
|
| 137 |
+
int64_t cols /* only used in SpMM variants */, int relu) {
|
| 138 |
+
for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
|
| 139 |
+
// Broadcast the biases by 4 to undo the division by 4 in the input biases.
|
| 140 |
+
__m256 sum1 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
|
| 141 |
+
_mm_broadcast_ss(bias_ptr));
|
| 142 |
+
bias_ptr += 2;
|
| 143 |
+
__m256 sum2 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
|
| 144 |
+
_mm_broadcast_ss(bias_ptr));
|
| 145 |
+
bias_ptr += 2;
|
| 146 |
+
|
| 147 |
+
int reduced_col_count = *nnz_per_row++;
|
| 148 |
+
for (int c = 0; c < reduced_col_count; ++c) {
|
| 149 |
+
int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
|
| 150 |
+
rhs_ptr += col_delta;
|
| 151 |
+
// Multiply this 4x4 block.
|
| 152 |
+
__m256 rhs =
|
| 153 |
+
_mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
|
| 154 |
+
__m256 weights1 = _mm256_load_ps(weights_ptr);
|
| 155 |
+
weights_ptr += 8;
|
| 156 |
+
sum1 = _mm256_add_ps(sum1, _mm256_mul_ps(weights1, rhs));
|
| 157 |
+
__m256 weights2 = _mm256_load_ps(weights_ptr);
|
| 158 |
+
weights_ptr += 8;
|
| 159 |
+
sum2 = _mm256_add_ps(sum2, _mm256_mul_ps(weights2, rhs));
|
| 160 |
+
}
|
| 161 |
+
Extract4Results(relu, sum1, sum2, &out_ptr);
|
| 162 |
+
}
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
|
| 166 |
+
// blocked pattern, x is a fat vector with 5 columns and b is vector. b is
|
| 167 |
+
// broadcast. Weights are stored for this routine by making each 4x4 block
|
| 168 |
+
// contiguous. Blocks are ordered in standard row-major format. column indices
|
| 169 |
+
// are converted to deltas and then multiplied by 2 to convert to bytes, so
|
| 170 |
+
// that the value can be used directly to offset the pointer into the rhs
|
| 171 |
+
// vector.
|
| 172 |
+
//
|
| 173 |
+
// NOTE: The bias is expected to have be multiplied by .25f prior to calling
|
| 174 |
+
// this function. This is automatically taken care of in SparseLinearLayer.
|
| 175 |
+
// The bias is reconstructed through horizontal additions, leads to a small
|
| 176 |
+
// speedup by reducing latencies at the end of the loop.
|
| 177 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 178 |
+
typename std::enable_if<std::is_same<WeightType, float>::value &&
|
| 179 |
+
std::is_same<RhsType, float>::value &&
|
| 180 |
+
std::is_same<OutType, float>::value>::type
|
| 181 |
+
SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
|
| 182 |
+
const int32_t* nnz_per_row, const RhsType* rhs_ptr,
|
| 183 |
+
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
|
| 184 |
+
OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
|
| 185 |
+
int relu) {
|
| 186 |
+
const RhsType* rhs_ptrs[5];
|
| 187 |
+
for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;
|
| 188 |
+
|
| 189 |
+
OutType* out_ptrs[5];
|
| 190 |
+
for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;
|
| 191 |
+
|
| 192 |
+
for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
|
| 193 |
+
// We will acumulate the results in 10 registers, |sum1_0| to |sum2_4|.
|
| 194 |
+
// Broadcast the biases by 4 to undo the division by 4 in the input biases.
|
| 195 |
+
__m256 sum1_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
|
| 196 |
+
_mm_broadcast_ss(bias_ptr));
|
| 197 |
+
bias_ptr += 2;
|
| 198 |
+
__m256 sum2_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
|
| 199 |
+
_mm_broadcast_ss(bias_ptr));
|
| 200 |
+
bias_ptr += 2;
|
| 201 |
+
__m256 sum1_1 = sum1_0;
|
| 202 |
+
__m256 sum2_1 = sum2_0;
|
| 203 |
+
__m256 sum1_2 = sum1_0;
|
| 204 |
+
__m256 sum2_2 = sum2_0;
|
| 205 |
+
__m256 sum1_3 = sum1_0;
|
| 206 |
+
__m256 sum2_3 = sum2_0;
|
| 207 |
+
__m256 sum1_4 = sum1_0;
|
| 208 |
+
__m256 sum2_4 = sum2_0;
|
| 209 |
+
|
| 210 |
+
int reduced_col_count = *nnz_per_row++;
|
| 211 |
+
for (int c = 0; c < reduced_col_count; ++c) {
|
| 212 |
+
int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
|
| 213 |
+
for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta;
|
| 214 |
+
|
| 215 |
+
// Multiply this 4x4 block.
|
| 216 |
+
__m256 rhs =
|
| 217 |
+
_mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[0]));
|
| 218 |
+
__m256 weights1 = _mm256_load_ps(weights_ptr);
|
| 219 |
+
weights_ptr += 8;
|
| 220 |
+
sum1_0 = _mm256_add_ps(sum1_0, _mm256_mul_ps(weights1, rhs));
|
| 221 |
+
__m256 weights2 = _mm256_load_ps(weights_ptr);
|
| 222 |
+
weights_ptr += 8;
|
| 223 |
+
sum2_0 = _mm256_add_ps(sum2_0, _mm256_mul_ps(weights2, rhs));
|
| 224 |
+
rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[1]));
|
| 225 |
+
sum1_1 = _mm256_add_ps(sum1_1, _mm256_mul_ps(weights1, rhs));
|
| 226 |
+
sum2_1 = _mm256_add_ps(sum2_1, _mm256_mul_ps(weights2, rhs));
|
| 227 |
+
rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[2]));
|
| 228 |
+
sum1_2 = _mm256_add_ps(sum1_2, _mm256_mul_ps(weights1, rhs));
|
| 229 |
+
sum2_2 = _mm256_add_ps(sum2_2, _mm256_mul_ps(weights2, rhs));
|
| 230 |
+
rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[3]));
|
| 231 |
+
sum1_3 = _mm256_add_ps(sum1_3, _mm256_mul_ps(weights1, rhs));
|
| 232 |
+
sum2_3 = _mm256_add_ps(sum2_3, _mm256_mul_ps(weights2, rhs));
|
| 233 |
+
rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[4]));
|
| 234 |
+
sum1_4 = _mm256_add_ps(sum1_4, _mm256_mul_ps(weights1, rhs));
|
| 235 |
+
sum2_4 = _mm256_add_ps(sum2_4, _mm256_mul_ps(weights2, rhs));
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
Extract4Results(relu, sum1_0, sum2_0, &out_ptrs[0]);
|
| 239 |
+
Extract4Results(relu, sum1_1, sum2_1, &out_ptrs[1]);
|
| 240 |
+
Extract4Results(relu, sum1_2, sum2_2, &out_ptrs[2]);
|
| 241 |
+
Extract4Results(relu, sum1_3, sum2_3, &out_ptrs[3]);
|
| 242 |
+
Extract4Results(relu, sum1_4, sum2_4, &out_ptrs[4]);
|
| 243 |
+
}
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
#ifdef __AVX2__
|
| 247 |
+
|
| 248 |
+
// In-line function to finish the computation of the result as 4x int32 in
|
| 249 |
+
// |sum|.
|
| 250 |
+
inline void Compute4Results(bool relu, int kShiftAmount, __m256i& sum) {
|
| 251 |
+
// Horizontally add the results. We have 1 register that contains results
|
| 252 |
+
// [0 0 1 1 2 2 3 3], but hadd (and almost no other AVX instruction) will not
|
| 253 |
+
// cross lanes, so we end up with [0 1 0 1 2 3 2 3]
|
| 254 |
+
sum = _mm256_hadd_epi32(sum, sum);
|
| 255 |
+
// Permutes the middle two pairs to get the answers together.
|
| 256 |
+
sum = _mm256_permute4x64_epi64(sum, 0xd8);
|
| 257 |
+
if (kShiftAmount > 0) {
|
| 258 |
+
// Shift right with rounding to get the right number of mantissa bits.
|
| 259 |
+
__m256i rounding = _mm256_set1_epi32(1 << (kShiftAmount - 1));
|
| 260 |
+
sum = _mm256_add_epi32(sum, rounding);
|
| 261 |
+
sum = _mm256_srai_epi32(sum, kShiftAmount);
|
| 262 |
+
}
|
| 263 |
+
// Now |sum| contains [|res0|, |res1|, |res2|, |res3|, |res0|, |res1|,
|
| 264 |
+
// |res2|, |res3|]
|
| 265 |
+
if (relu) {
|
| 266 |
+
sum = _mm256_max_epi32(sum, _mm256_setzero_si256());
|
| 267 |
+
}
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
// In-line function to extract the 4x int32 results from |sum| to memory.
|
| 271 |
+
// Non-const reference for |sum| as it is a register.
|
| 272 |
+
inline void Extract4xint32(bool relu, int kShiftAmount, __m256i& sum,
|
| 273 |
+
int32_t** out_ptr) {
|
| 274 |
+
Compute4Results(relu, kShiftAmount, sum);
|
| 275 |
+
// Save the lower 128 bits (4x int32).
|
| 276 |
+
__m128i result = _mm256_extractf128_si256(sum, 0);
|
| 277 |
+
_mm_store_si128(reinterpret_cast<__m128i*>(*out_ptr), result);
|
| 278 |
+
*out_ptr += 4;
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
// In-line function to extract the 4x int32 results from sum to 4x int16 in
|
| 282 |
+
// memory.
|
| 283 |
+
// Non-const reference for |sum| as it is a register.
|
| 284 |
+
inline void Extract4xint16(bool relu, int kShiftAmount, __m256i& sum,
|
| 285 |
+
int16_t** out_ptr) {
|
| 286 |
+
Compute4Results(relu, kShiftAmount, sum);
|
| 287 |
+
// Clip to 16 bit range (with saturation) and pack in the bottom 64 bits.
|
| 288 |
+
// Converts the lower 4x int32 in bottom 128 bits to 4x int16 in bottom 64
|
| 289 |
+
// bits, replicated in the next 64 bits.
|
| 290 |
+
sum = _mm256_packs_epi32(sum, sum);
|
| 291 |
+
// Save 4x int 16 from the bottom 64 bits.
|
| 292 |
+
*reinterpret_cast<int64_t*>(*out_ptr) = _mm256_extract_epi64(sum, 0);
|
| 293 |
+
*out_ptr += 4;
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
|
| 297 |
+
// blocked pattern, x is a vector and b is vector. Weights are stored for this
|
| 298 |
+
// routine by making each 4x4 block contiguous. Blocks are ordered in standard
|
| 299 |
+
// row-major format. column indices are converted to deltas and then multiplied
|
| 300 |
+
// by 2 to convert to bytes, so that the value can be used directly to offset
|
| 301 |
+
// the pointer into the rhs vector.
|
| 302 |
+
//
|
| 303 |
+
// NOTE: The bias is expected to have be multiplied by .25f prior to calling
|
| 304 |
+
// this function. This is automatically taken care of in SparseLinearLayer.
|
| 305 |
+
// The bias is reconstructed through horizontal additions, leads to a small
|
| 306 |
+
// speedup by reducing latencies at the end of the loop.
|
| 307 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 308 |
+
typename std::enable_if<
|
| 309 |
+
IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value &&
|
| 310 |
+
(IsFixed32Type<OutType>::value || IsFixed16Type<OutType>::value)>::type
|
| 311 |
+
SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
|
| 312 |
+
const int32_t* nnz_per_row, const RhsType* rhs_ptr,
|
| 313 |
+
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
|
| 314 |
+
OutType* out_ptr, int64_t assigned_rows,
|
| 315 |
+
int64_t rows /* only used in SpMM variants */,
|
| 316 |
+
int64_t cols /* only used in SpMM variants */, int relu) {
|
| 317 |
+
constexpr int kShiftAmount =
|
| 318 |
+
TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
|
| 319 |
+
OutType::kMantissaBits;
|
| 320 |
+
static_assert(kShiftAmount >= 0,
|
| 321 |
+
"Result must have fewer mantissa bits than product");
|
| 322 |
+
for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
|
| 323 |
+
// Load the biases duplicated into a 256 bit register [0 1 2 3 0 1 2 3].
|
| 324 |
+
__m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr));
|
| 325 |
+
__m256i biases = _mm256_set_m128i(bias, bias);
|
| 326 |
+
bias_ptr += 4;
|
| 327 |
+
// Swap the top two pairs: [0 1 2 3 2 3 0 1]
|
| 328 |
+
// TODO(b/188702959): consider |_mm256_permutevar8x32|, and set the index
|
| 329 |
+
// register outside the row loop.
|
| 330 |
+
biases = _mm256_permute4x64_epi64(biases, 0xb4);
|
| 331 |
+
// Duplicate the low pairs in each lane: [0 0 1 1 2 2 3 3].
|
| 332 |
+
biases = _mm256_unpacklo_epi32(biases, biases);
|
| 333 |
+
// Double the results to make up for the division by 4.
|
| 334 |
+
// TODO(b/188702959): consider moving this to where the biases are computed.
|
| 335 |
+
__m256i sum = _mm256_add_epi32(biases, biases);
|
| 336 |
+
|
| 337 |
+
// TODO(b/188702959): People don't like the old-fashioned, close-to-the-
|
| 338 |
+
// metal notation of *|nnz_per_row|++, so measure the effect of putting the
|
| 339 |
+
// increment in the for loop.
|
| 340 |
+
int reduced_col_count = *nnz_per_row;
|
| 341 |
+
++nnz_per_row;
|
| 342 |
+
for (int c = 0; c < reduced_col_count; ++c) {
|
| 343 |
+
int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
|
| 344 |
+
rhs_ptr += col_delta;
|
| 345 |
+
// Multiply this 4x4 block.
|
| 346 |
+
// Get the 4x int16 into the bottom of rhs_64.
|
| 347 |
+
__m128i rhs_64 =
|
| 348 |
+
_mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptr));
|
| 349 |
+
// Load all 16 weights.
|
| 350 |
+
__m256i weights =
|
| 351 |
+
_mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
|
| 352 |
+
// Broadcast the rhs, pretending that each is a 64-bit unit:
|
| 353 |
+
// [0123 0123 0123 0123].
|
| 354 |
+
__m256i rhs = _mm256_broadcastq_epi64(rhs_64);
|
| 355 |
+
weights_ptr += 16;
|
| 356 |
+
// |_mm256_madd_epi16| does 16x16x16=16x32 bit multiply and horizontally
|
| 357 |
+
// adds adjacent pairs to make 8x32 bit results. Add these to the sum.
|
| 358 |
+
sum = _mm256_add_epi32(sum, _mm256_madd_epi16(weights, rhs));
|
| 359 |
+
}
|
| 360 |
+
static_assert(
|
| 361 |
+
IsFixed16Type<OutType>::value || IsFixed32Type<OutType>::value,
|
| 362 |
+
"AVX2 kernel only supports fixed16 and fixed32 types");
|
| 363 |
+
// The only significant difference between fixed16 and fixed32 is the size
|
| 364 |
+
// of the storage unit. The registers have to be repacked accordingly.
|
| 365 |
+
if (IsFixed32Type<OutType>::value) {
|
| 366 |
+
Extract4xint32(relu, kShiftAmount, sum,
|
| 367 |
+
reinterpret_cast<int32_t**>(&out_ptr));
|
| 368 |
+
} else {
|
| 369 |
+
Extract4xint16(relu, kShiftAmount, sum,
|
| 370 |
+
reinterpret_cast<int16_t**>(&out_ptr));
|
| 371 |
+
}
|
| 372 |
+
}
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
|
| 376 |
+
// blocked pattern, x is a fat vector with 5 columns and b is vector. b is
|
| 377 |
+
// broadcast. Weights are stored for this routine by making each 4x4 block
|
| 378 |
+
// contiguous. Blocks are ordered in standard row-major format. column indices
|
| 379 |
+
// are converted to deltas and then multiplied by 2 to convert to bytes, so
|
| 380 |
+
// that the value can be used directly to offset the pointer into the rhs
|
| 381 |
+
// vector.
|
| 382 |
+
//
|
| 383 |
+
// NOTE: The bias is expected to have be multiplied by .25f prior to calling
|
| 384 |
+
// this function. This is automatically taken care of in SparseLinearLayer.
|
| 385 |
+
// The bias is reconstructed through horizontal additions, leads to a small
|
| 386 |
+
// speedup by reducing latencies at the end of the loop.
|
| 387 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 388 |
+
typename std::enable_if<
|
| 389 |
+
IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value &&
|
| 390 |
+
(IsFixed32Type<OutType>::value || IsFixed16Type<OutType>::value)>::type
|
| 391 |
+
SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
|
| 392 |
+
const int32_t* nnz_per_row, const RhsType* rhs_ptr,
|
| 393 |
+
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
|
| 394 |
+
OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
|
| 395 |
+
int relu) {
|
| 396 |
+
constexpr int kShiftAmount =
|
| 397 |
+
TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
|
| 398 |
+
OutType::kMantissaBits;
|
| 399 |
+
static_assert(kShiftAmount >= 0,
|
| 400 |
+
"Result must have fewer mantissa bits than product");
|
| 401 |
+
const RhsType* rhs_ptrs[5];
|
| 402 |
+
for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;
|
| 403 |
+
|
| 404 |
+
OutType* out_ptrs[5];
|
| 405 |
+
for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;
|
| 406 |
+
|
| 407 |
+
for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
|
| 408 |
+
// We will acumulate the results in 5 registers, sum_0 to sum_4.
|
| 409 |
+
// Load the biases duplicated into a 256 bit register [0 1 2 3 0 1 2 3].
|
| 410 |
+
__m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr));
|
| 411 |
+
__m256i biases = _mm256_set_m128i(bias, bias);
|
| 412 |
+
bias_ptr += 4;
|
| 413 |
+
// Swap the top two pairs: [0 1 2 3 2 3 0 1]
|
| 414 |
+
biases = _mm256_permute4x64_epi64(biases, 0xb4);
|
| 415 |
+
// Duplicate the low pairs in each lane: [0 0 1 1 2 2 3 3].
|
| 416 |
+
biases = _mm256_unpacklo_epi32(biases, biases);
|
| 417 |
+
// Double the results to make up for the division by 4.
|
| 418 |
+
__m256i sum_0 = _mm256_add_epi32(biases, biases);
|
| 419 |
+
__m256i sum_1 = sum_0;
|
| 420 |
+
__m256i sum_2 = sum_0;
|
| 421 |
+
__m256i sum_3 = sum_0;
|
| 422 |
+
__m256i sum_4 = sum_0;
|
| 423 |
+
|
| 424 |
+
int reduced_col_count = *nnz_per_row;
|
| 425 |
+
++nnz_per_row;
|
| 426 |
+
for (int c = 0; c < reduced_col_count; ++c) {
|
| 427 |
+
int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
|
| 428 |
+
for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta;
|
| 429 |
+
// Multiply this 4x4 block.
|
| 430 |
+
// Get the 4x int16 into the bottom of |rhs_64|.
|
| 431 |
+
__m128i rhs_64 =
|
| 432 |
+
_mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[0]));
|
| 433 |
+
// Load all 16 weights.
|
| 434 |
+
__m256i weights =
|
| 435 |
+
_mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
|
| 436 |
+
// Broadcast the rhs, pretending that each is a 64-bit unit:
|
| 437 |
+
// [0123 0123 0123 0123].
|
| 438 |
+
__m256i rhs = _mm256_broadcastq_epi64(rhs_64);
|
| 439 |
+
weights_ptr += 16;
|
| 440 |
+
// |_mm256_madd_epi16| does 16x16x16=16x32 bit multiply and horizontally
|
| 441 |
+
// adds adjacent pairs to make 8x32 bit results. Add these to the sum.
|
| 442 |
+
sum_0 = _mm256_add_epi32(sum_0, _mm256_madd_epi16(weights, rhs));
|
| 443 |
+
rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[1]));
|
| 444 |
+
rhs = _mm256_broadcastq_epi64(rhs_64);
|
| 445 |
+
sum_1 = _mm256_add_epi32(sum_1, _mm256_madd_epi16(weights, rhs));
|
| 446 |
+
rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[2]));
|
| 447 |
+
rhs = _mm256_broadcastq_epi64(rhs_64);
|
| 448 |
+
sum_2 = _mm256_add_epi32(sum_2, _mm256_madd_epi16(weights, rhs));
|
| 449 |
+
rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[3]));
|
| 450 |
+
rhs = _mm256_broadcastq_epi64(rhs_64);
|
| 451 |
+
sum_3 = _mm256_add_epi32(sum_3, _mm256_madd_epi16(weights, rhs));
|
| 452 |
+
rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[4]));
|
| 453 |
+
rhs = _mm256_broadcastq_epi64(rhs_64);
|
| 454 |
+
sum_4 = _mm256_add_epi32(sum_4, _mm256_madd_epi16(weights, rhs));
|
| 455 |
+
}
|
| 456 |
+
static_assert(
|
| 457 |
+
IsFixed16Type<OutType>::value || IsFixed32Type<OutType>::value,
|
| 458 |
+
"AVX2 kernel only supports fixed16 and fixed32 types");
|
| 459 |
+
// The only significant difference between fixed16 and fixed32 is the size
|
| 460 |
+
// of the storage unit. The registers have to be repacked accordingly.
|
| 461 |
+
if (IsFixed32Type<OutType>::value) {
|
| 462 |
+
Extract4xint32(relu, kShiftAmount, sum_0,
|
| 463 |
+
reinterpret_cast<int32_t**>(&out_ptrs[0]));
|
| 464 |
+
Extract4xint32(relu, kShiftAmount, sum_1,
|
| 465 |
+
reinterpret_cast<int32_t**>(&out_ptrs[1]));
|
| 466 |
+
Extract4xint32(relu, kShiftAmount, sum_2,
|
| 467 |
+
reinterpret_cast<int32_t**>(&out_ptrs[2]));
|
| 468 |
+
Extract4xint32(relu, kShiftAmount, sum_3,
|
| 469 |
+
reinterpret_cast<int32_t**>(&out_ptrs[3]));
|
| 470 |
+
Extract4xint32(relu, kShiftAmount, sum_4,
|
| 471 |
+
reinterpret_cast<int32_t**>(&out_ptrs[4]));
|
| 472 |
+
} else {
|
| 473 |
+
Extract4xint16(relu, kShiftAmount, sum_0,
|
| 474 |
+
reinterpret_cast<int16_t**>(&out_ptrs[0]));
|
| 475 |
+
Extract4xint16(relu, kShiftAmount, sum_1,
|
| 476 |
+
reinterpret_cast<int16_t**>(&out_ptrs[1]));
|
| 477 |
+
Extract4xint16(relu, kShiftAmount, sum_2,
|
| 478 |
+
reinterpret_cast<int16_t**>(&out_ptrs[2]));
|
| 479 |
+
Extract4xint16(relu, kShiftAmount, sum_3,
|
| 480 |
+
reinterpret_cast<int16_t**>(&out_ptrs[3]));
|
| 481 |
+
Extract4xint16(relu, kShiftAmount, sum_4,
|
| 482 |
+
reinterpret_cast<int16_t**>(&out_ptrs[4]));
|
| 483 |
+
}
|
| 484 |
+
}
|
| 485 |
+
}
|
| 486 |
+
|
| 487 |
+
// Processes one GRU gate input with sigmoid.
|
| 488 |
+
template <int InputMantissaBits, int StateMantissaBits, bool SplitGates>
|
| 489 |
+
inline __m256i GRUGateSigmoid(const void* gate_ptr, const void* gate_other_ptr,
|
| 490 |
+
const __m256i& input,
|
| 491 |
+
const int32_t* sigmoid_table) {
|
| 492 |
+
__m256i gate = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(gate_ptr));
|
| 493 |
+
if (SplitGates) {
|
| 494 |
+
__m256i other =
|
| 495 |
+
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(gate_other_ptr));
|
| 496 |
+
gate = _mm256_add_epi32(gate, other);
|
| 497 |
+
}
|
| 498 |
+
gate = _mm256_add_epi32(gate, input);
|
| 499 |
+
// Compute sigmoids on reset and update.
|
| 500 |
+
return csrblocksparse::fixed32_sigmoid_fixed16<InputMantissaBits,
|
| 501 |
+
StateMantissaBits>(
|
| 502 |
+
sigmoid_table, gate);
|
| 503 |
+
}
|
| 504 |
+
|
| 505 |
+
// Processes the tanh and the final combination, returning the new GRU state.
|
| 506 |
+
template <int InputMantissaBits, int StateMantissaBits, bool SplitGates = false>
|
| 507 |
+
inline __m256i GRUGateState(const __m256i& cell, const __m256i& reset,
|
| 508 |
+
const __m256i& update,
|
| 509 |
+
const __m256i& rounding_offset,
|
| 510 |
+
const void* gate_ptr, const void* gate_other_ptr,
|
| 511 |
+
const void* gru_h_ptr, const int32_t* tanh_table) {
|
| 512 |
+
// Multiply the cell GRU output and the reset. There is a slight danger of
|
| 513 |
+
// loss of precision here, so use 32x32=64 bit and shift back after.
|
| 514 |
+
__m256i gru = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_ptr));
|
| 515 |
+
if (SplitGates) {
|
| 516 |
+
__m256i other_gru =
|
| 517 |
+
_mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_other_ptr));
|
| 518 |
+
gru = _mm256_add_epi32(gru, other_gru);
|
| 519 |
+
}
|
| 520 |
+
// This only computes the products of the low-order 32 bits of each pair.
|
| 521 |
+
__m256i gru_lo = _mm256_mul_epi32(gru, reset);
|
| 522 |
+
// Swap odd and even 32-bit units and do it again to get the high products.
|
| 523 |
+
gru = _mm256_shuffle_epi32(gru, 0xb1);
|
| 524 |
+
__m256i gru_hi = _mm256_mul_epi32(gru, _mm256_shuffle_epi32(reset, 0xb1));
|
| 525 |
+
// Now shift right to compensate for the multiply and re-interleave the
|
| 526 |
+
// 32-bit results.
|
| 527 |
+
// NOTE: There is no shift right arithmetic for 64 bit values until AVX512!
|
| 528 |
+
// Fortunately it doesn't matter, as the results are being truncated to 32
|
| 529 |
+
// bits and we aren't shifting right by more than 32 bits here.
|
| 530 |
+
gru_lo = _mm256_srli_epi64(gru_lo, StateMantissaBits);
|
| 531 |
+
// The upper results are shifted LEFT, so we can use blend to recombine in
|
| 532 |
+
// a single instruction.
|
| 533 |
+
gru_hi = _mm256_slli_epi64(gru_hi, 32 - StateMantissaBits);
|
| 534 |
+
// Recombine the 32 bit results from lo and hi, alternating.
|
| 535 |
+
gru = _mm256_blend_epi32(gru_lo, gru_hi, 0xaa);
|
| 536 |
+
gru = _mm256_add_epi32(cell, gru);
|
| 537 |
+
// Compute tanh on the result. Although this instantly discards a bunch of
|
| 538 |
+
// bits, there were only 7 surplus bits for the multiply, which isn't enough
|
| 539 |
+
// to do it as 16x16=32.
|
| 540 |
+
__m256i hbar =
|
| 541 |
+
csrblocksparse::fixed32_tanh_fixed16<InputMantissaBits,
|
| 542 |
+
StateMantissaBits>(tanh_table, gru);
|
| 543 |
+
// Load the 16-bit previous GRU state and sign-extend to 32 bits.
|
| 544 |
+
gru = _mm256_cvtepi16_epi32(
|
| 545 |
+
_mm_load_si128(reinterpret_cast<__m128i const*>(gru_h_ptr)));
|
| 546 |
+
gru = _mm256_sub_epi32(gru, hbar);
|
| 547 |
+
// Since |gru| is 16 bit sign-extended to 32, and |update| is the output of
|
| 548 |
+
// sigmoid, it is always contained within 16 bits and never negative, we can
|
| 549 |
+
// use |madd_epi16| to do 16x16=32 multiply with horizontal adding as the
|
| 550 |
+
// addend will always be zero, and this is twice as fast as full blown
|
| 551 |
+
// 32x32=32. The only possible problem is if the subtract above caused
|
| 552 |
+
// overflow.
|
| 553 |
+
gru = _mm256_madd_epi16(gru, update);
|
| 554 |
+
// Renormalize to fixed16. This time rounding is critical, as this is the
|
| 555 |
+
// output GRU state.
|
| 556 |
+
gru = _mm256_add_epi32(gru, rounding_offset);
|
| 557 |
+
gru = _mm256_srai_epi32(gru, StateMantissaBits);
|
| 558 |
+
return _mm256_add_epi32(gru, hbar);
|
| 559 |
+
}
|
| 560 |
+
|
| 561 |
+
template <typename Type>
|
| 562 |
+
typename std::enable_if<IsFixed32Type<Type>::value>::type SumVectors(
|
| 563 |
+
int start, int end, const Type* add1, const Type* add2, Type* result) {
|
| 564 |
+
constexpr int kSIMDWidth = 8;
|
| 565 |
+
for (int i = start; i < end; i += kSIMDWidth) {
|
| 566 |
+
__m256i data1 =
|
| 567 |
+
_mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i));
|
| 568 |
+
__m256i data2 =
|
| 569 |
+
_mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i));
|
| 570 |
+
data1 = _mm256_add_epi32(data1, data2);
|
| 571 |
+
_mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1);
|
| 572 |
+
}
|
| 573 |
+
}
|
| 574 |
+
|
| 575 |
+
template <typename Type>
|
| 576 |
+
typename std::enable_if<IsFixed16Type<Type>::value>::type SumVectors(
|
| 577 |
+
int start, int end, const Type* add1, const Type* add2, Type* result) {
|
| 578 |
+
constexpr int kSIMDWidth = 16;
|
| 579 |
+
for (int i = start; i < end; i += kSIMDWidth) {
|
| 580 |
+
__m256i data1 =
|
| 581 |
+
_mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i));
|
| 582 |
+
__m256i data2 =
|
| 583 |
+
_mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i));
|
| 584 |
+
data1 = _mm256_add_epi16(data1, data2);
|
| 585 |
+
_mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1);
|
| 586 |
+
}
|
| 587 |
+
}
|
| 588 |
+
|
| 589 |
+
#endif // __AVX2__
|
| 590 |
+
|
| 591 |
+
} // namespace detail
|
| 592 |
+
} // namespace csrblocksparse
|
| 593 |
+
|
| 594 |
+
#undef LABEL_COL_LOOP
|
| 595 |
+
#undef LABEL_ROW_LOOP
|
| 596 |
+
#undef LABEL_SKIP_COL_LOOP
|
| 597 |
+
#undef LABEL_TOP_LOOP
|
| 598 |
+
|
| 599 |
+
#endif // __AVX__
|
| 600 |
+
|
| 601 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_
|
sparse_matmul/compute/kernels_generic.h
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2021 Google LLC
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_
|
| 18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_
|
| 19 |
+
|
| 20 |
+
#include <algorithm>
|
| 21 |
+
#include <type_traits>
|
| 22 |
+
|
| 23 |
+
#include "sparse_matmul/numerics/fixed_types.h"
|
| 24 |
+
#include "sparse_matmul/numerics/float16_types.h"
|
| 25 |
+
#include "sparse_matmul/numerics/type_utils.h"
|
| 26 |
+
|
| 27 |
+
// Separate out the assembly kernels for readability. Eventually this will
|
| 28 |
+
// become an ifdef switch on the architecture type.
|
| 29 |
+
#if defined __aarch64__
|
| 30 |
+
#include "sparse_matmul/compute/kernels_arm.h"
|
| 31 |
+
#elif defined __AVX__
|
| 32 |
+
#include "sparse_matmul/compute/kernels_avx.h"
|
| 33 |
+
#else // defined __AVX__
|
| 34 |
+
// If there is no architecture-specific implementation, then always use generic.
|
| 35 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 36 |
+
struct ShouldEnableGenericSpMV_4x4 : std::true_type {};
|
| 37 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 38 |
+
struct ShouldEnableGenericSpMM5_4x4 : std::true_type {};
|
| 39 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 40 |
+
struct ShouldEnableGenericSpMV_1x1 : std::true_type {};
|
| 41 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 42 |
+
struct ShouldEnableGenericSpMM5_1x1 : std::true_type {};
|
| 43 |
+
template <typename Type>
|
| 44 |
+
struct ShouldEnableGenericAdd : std::true_type {};
|
| 45 |
+
#endif // defined __arch64__
|
| 46 |
+
|
| 47 |
+
namespace csrblocksparse {
|
| 48 |
+
namespace detail {
|
| 49 |
+
|
| 50 |
+
// The computational routines do NO error checking for speed. It is assumed
|
| 51 |
+
// that this has been handled by CSRBlockSparseMatrix.
|
| 52 |
+
|
| 53 |
+
// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
|
| 54 |
+
// blocked pattern, x is a vector and b is vector. Weights are stored for this
|
| 55 |
+
// routine by making each 4x4 block contiguous. Blocks are ordered in standard
|
| 56 |
+
// row-major format. column indices are converted to deltas and then multiplied
|
| 57 |
+
// by 2 to convert to bytes, so that the value can be used directly to offset
|
| 58 |
+
// the pointer into the rhs vector.
|
| 59 |
+
//
|
| 60 |
+
// NOTE: The bias is expected to have be multiplied by .25f prior to calling
|
| 61 |
+
// this function. This is automatically taken care of in SparseLinearLayer.
|
| 62 |
+
// The bias is reconstructed through horizontal additions, leads to a small
|
| 63 |
+
// speedup by reducing latencies at the end of the loop.
|
| 64 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 65 |
+
typename std::enable_if<
|
| 66 |
+
ShouldEnableGenericSpMV_4x4<WeightType, RhsType, OutType>::value>::type
|
| 67 |
+
SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
|
| 68 |
+
const int32_t* nnz_per_row, const RhsType* rhs_ptr,
|
| 69 |
+
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
|
| 70 |
+
OutType* out_ptr, int64_t assigned_rows,
|
| 71 |
+
int64_t rows /* only used in SpMM variants */,
|
| 72 |
+
int64_t cols /* only used in SpMM variants */, int relu) {
|
| 73 |
+
for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
|
| 74 |
+
float accumulators[4];
|
| 75 |
+
// Undo the divion by the happens for the assembly version.
|
| 76 |
+
for (int i = 0; i < 4; ++i)
|
| 77 |
+
accumulators[i] = 4.f * static_cast<float>(*bias_ptr++);
|
| 78 |
+
|
| 79 |
+
int reduced_col_count = *nnz_per_row++;
|
| 80 |
+
for (int c = 0; c < reduced_col_count; ++c) {
|
| 81 |
+
int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
|
| 82 |
+
rhs_ptr += col_delta;
|
| 83 |
+
|
| 84 |
+
// Multiply this 4x4 block.
|
| 85 |
+
for (int i = 0; i < 4; ++i) {
|
| 86 |
+
for (int j = 0; j < 4; ++j) {
|
| 87 |
+
accumulators[i] += static_cast<float>(*weights_ptr++) *
|
| 88 |
+
static_cast<float>(rhs_ptr[j]);
|
| 89 |
+
}
|
| 90 |
+
}
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
for (int i = 0; i < 4; ++i)
|
| 94 |
+
*out_ptr++ = static_cast<OutType>(relu ? std::max(accumulators[i], 0.f)
|
| 95 |
+
: accumulators[i]);
|
| 96 |
+
}
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
|
| 100 |
+
// blocked pattern, x is a fat vector with 5 columns and b is vector. b is
|
| 101 |
+
// broadcast. Weights are stored for this routine by making each 4x4 block
|
| 102 |
+
// contiguous. Blocks are ordered in standard row-major format. column indices
|
| 103 |
+
// are converted to deltas and then multiplied by 2 to convert to bytes, so
|
| 104 |
+
// that the value can be used directly to offset the pointer into the rhs
|
| 105 |
+
// vector.
|
| 106 |
+
//
|
| 107 |
+
// NOTE: The bias is expected to have be multiplied by .25f prior to calling
|
| 108 |
+
// this function. This is automatically taken care of in SparseLinearLayer.
|
| 109 |
+
// The bias is reconstructed through horizontal additions, leads to a small
|
| 110 |
+
// speedup by reducing latencies at the end of the loop.
|
| 111 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 112 |
+
typename std::enable_if<
|
| 113 |
+
ShouldEnableGenericSpMM5_4x4<WeightType, RhsType, OutType>::value>::type
|
| 114 |
+
SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
|
| 115 |
+
const int32_t* nnz_per_row, const RhsType* rhs_ptr,
|
| 116 |
+
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
|
| 117 |
+
OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
|
| 118 |
+
int relu) {
|
| 119 |
+
const RhsType* rhs_ptrs[5];
|
| 120 |
+
for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;
|
| 121 |
+
|
| 122 |
+
OutType* out_ptrs[5];
|
| 123 |
+
for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;
|
| 124 |
+
|
| 125 |
+
for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
|
| 126 |
+
float accumulators[4][5];
|
| 127 |
+
// Undo the divion by the happens for the assembly version.
|
| 128 |
+
for (int i = 0; i < 4; ++i) {
|
| 129 |
+
for (int k = 0; k < 5; ++k) {
|
| 130 |
+
accumulators[i][k] = 4.f * static_cast<float>(*bias_ptr);
|
| 131 |
+
}
|
| 132 |
+
++bias_ptr;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
int reduced_col_count = *nnz_per_row++;
|
| 136 |
+
for (int c = 0; c < reduced_col_count; ++c) {
|
| 137 |
+
int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
|
| 138 |
+
for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta;
|
| 139 |
+
|
| 140 |
+
// multiply this 4x4 block
|
| 141 |
+
for (int i = 0; i < 4; ++i) {
|
| 142 |
+
for (int j = 0; j < 4; ++j) {
|
| 143 |
+
for (int k = 0; k < 5; ++k) {
|
| 144 |
+
accumulators[i][k] += static_cast<float>(*weights_ptr) *
|
| 145 |
+
static_cast<float>(rhs_ptrs[k][j]);
|
| 146 |
+
}
|
| 147 |
+
weights_ptr++;
|
| 148 |
+
}
|
| 149 |
+
}
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
for (int k = 0; k < 5; ++k) {
|
| 153 |
+
for (int i = 0; i < 4; ++i) {
|
| 154 |
+
out_ptrs[k][0] = static_cast<OutType>(
|
| 155 |
+
relu ? std::max(accumulators[i][k], 0.f) : accumulators[i][k]);
|
| 156 |
+
out_ptrs[k]++;
|
| 157 |
+
}
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
// Performs the calculation y = A * x + b where A is a sparse matrix with
|
| 163 |
+
// a 1x1 blocked pattern (ie unstructured), x is a
|
| 164 |
+
// vector and b is vector.
|
| 165 |
+
// Weights are stored for this routine in standard CSR format. Each row must
|
| 166 |
+
// have a multiple of 8 columns.
|
| 167 |
+
// column indices are converted to deltas and then multiplied by 2 to convert
|
| 168 |
+
// to bytes, so that the value can be used directly to offset the pointer
|
| 169 |
+
// into the rhs vector.
|
| 170 |
+
// NOTE: The bias is expected to have be multiplied by .25f prior to calling
|
| 171 |
+
// this function. This is automatically taken care of in SparseLinearLayer.
|
| 172 |
+
// The bias is reconstructed through horizontal additions, leads to a small
|
| 173 |
+
// speedup by reducing latencies at the end of the loop.
|
| 174 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 175 |
+
typename std::enable_if<
|
| 176 |
+
ShouldEnableGenericSpMV_1x1<WeightType, RhsType, OutType>::value>::type
|
| 177 |
+
SpMV_1x1(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
|
| 178 |
+
const int32_t* nnz_per_row, const RhsType* rhs_ptr,
|
| 179 |
+
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
|
| 180 |
+
OutType* out_ptr, int64_t assigned_rows,
|
| 181 |
+
int64_t rows /* only used in SpMM variants */,
|
| 182 |
+
int64_t cols /* only used in SpMM variants */, int relu) {
|
| 183 |
+
for (int row = 0; row < assigned_rows; ++row) {
|
| 184 |
+
// Undo the divion by the happens for the assembly version.
|
| 185 |
+
float accumulator = 4.f * static_cast<float>(*bias_ptr++);
|
| 186 |
+
|
| 187 |
+
int col_count = *nnz_per_row++;
|
| 188 |
+
for (int c = 0; c < col_count; ++c) {
|
| 189 |
+
int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
|
| 190 |
+
rhs_ptr += col_delta;
|
| 191 |
+
|
| 192 |
+
accumulator +=
|
| 193 |
+
static_cast<float>(*weights_ptr++) * static_cast<float>(*rhs_ptr);
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
*out_ptr++ =
|
| 197 |
+
static_cast<OutType>(relu ? std::max(accumulator, 0.f) : accumulator);
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
// Performs the calculation y = A * x + b where A is a sparse matrix with
|
| 202 |
+
// a 1x1 blocked pattern (ie unstructured), x is a
|
| 203 |
+
// vector and b is vector.
|
| 204 |
+
// Weights are stored for this routine in standard CSR format. Each row must
|
| 205 |
+
// have a multiple of 8 columns.
|
| 206 |
+
// column indices are converted to deltas and then multiplied by 2 to convert
|
| 207 |
+
// to bytes, so that the value can be used directly to offset the pointer
|
| 208 |
+
// into the rhs vector.
|
| 209 |
+
// NOTE: The bias is expected to have be multiplied by .25f prior to calling
|
| 210 |
+
// this function. This is automatically taken care of in SparseLinearLayer.
|
| 211 |
+
// The bias is reconstructed through horizontal additions, leads to a small
|
| 212 |
+
// speedup by reducing latencies at the end of the loop.
|
| 213 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
| 214 |
+
typename std::enable_if<
|
| 215 |
+
ShouldEnableGenericSpMM5_1x1<WeightType, RhsType, OutType>::value>::type
|
| 216 |
+
SpMM5_1x1(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
|
| 217 |
+
const int32_t* nnz_per_row, const RhsType* rhs_ptr,
|
| 218 |
+
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
|
| 219 |
+
OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
|
| 220 |
+
int relu) {
|
| 221 |
+
const RhsType* rhs_ptrs[5];
|
| 222 |
+
for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;
|
| 223 |
+
|
| 224 |
+
OutType* out_ptrs[5];
|
| 225 |
+
for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;
|
| 226 |
+
|
| 227 |
+
for (int row = 0; row < assigned_rows; ++row) {
|
| 228 |
+
// Undo the divion by the happens for the assembly version.
|
| 229 |
+
float accumulator[5];
|
| 230 |
+
for (int i = 0; i < 5; ++i)
|
| 231 |
+
accumulator[i] = 4.f * static_cast<float>(*bias_ptr);
|
| 232 |
+
|
| 233 |
+
++bias_ptr;
|
| 234 |
+
|
| 235 |
+
int col_count = *nnz_per_row++;
|
| 236 |
+
for (int c = 0; c < col_count; ++c) {
|
| 237 |
+
int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
|
| 238 |
+
for (int i = 0; i < 5; ++i) {
|
| 239 |
+
rhs_ptrs[i] += col_delta;
|
| 240 |
+
accumulator[i] += static_cast<float>(*weights_ptr) *
|
| 241 |
+
static_cast<float>(rhs_ptrs[i][0]);
|
| 242 |
+
}
|
| 243 |
+
weights_ptr++;
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
for (int i = 0; i < 5; ++i) {
|
| 247 |
+
out_ptrs[i][0] = static_cast<OutType>(relu ? std::max(accumulator[i], 0.f)
|
| 248 |
+
: accumulator[i]);
|
| 249 |
+
out_ptrs[i]++;
|
| 250 |
+
}
|
| 251 |
+
}
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
template <typename Type>
|
| 255 |
+
typename std::enable_if<ShouldEnableGenericAdd<Type>::value>::type SumVectors(
|
| 256 |
+
int start, int end, const Type* add1, const Type* add2, Type* result) {
|
| 257 |
+
LOG_FIRST_N(WARNING, 1) << "SumVectors: using generic kernel!";
|
| 258 |
+
for (int i = start; i < end; ++i) {
|
| 259 |
+
Type sum = static_cast<Type>(static_cast<float>(add1[i]) +
|
| 260 |
+
static_cast<float>(add2[i]));
|
| 261 |
+
result[i] = sum;
|
| 262 |
+
}
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
} // namespace detail
|
| 266 |
+
} // namespace csrblocksparse
|
| 267 |
+
|
| 268 |
+
#undef LABEL_COL_LOOP
|
| 269 |
+
#undef LABEL_ROW_LOOP
|
| 270 |
+
#undef LABEL_SKIP_COL_LOOP
|
| 271 |
+
#undef LABEL_TOP_LOOP
|
| 272 |
+
|
| 273 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_
|
sparse_matmul/compute/matmul.h
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2021 Google LLC
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_
|
| 18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_
|
| 19 |
+
|
| 20 |
+
#include <cstdint>
|
| 21 |
+
#include <vector>
|
| 22 |
+
|
| 23 |
+
#include "absl/time/time.h"
|
| 24 |
+
#include "sparse_matmul/compute/matmul_fixed_avx2.h"
|
| 25 |
+
#include "sparse_matmul/compute/matmul_generic.h"
|
| 26 |
+
#include "sparse_matmul/numerics/fixed_types.h"
|
| 27 |
+
#include "sparse_matmul/numerics/type_utils.h"
|
| 28 |
+
#if defined(__x86_64__) || defined(__i386__) || defined(_WIN32)
|
| 29 |
+
#include <cpuid.h>
|
| 30 |
+
#endif
|
| 31 |
+
|
| 32 |
+
namespace csrblocksparse {
|
| 33 |
+
|
| 34 |
+
// The number of elements in a block.
|
| 35 |
+
constexpr int kBlockSize = 4;
|
| 36 |
+
|
| 37 |
+
// Base class for Matmul containing the members that are non type-specicfic.
|
| 38 |
+
class MatmulBase {
|
| 39 |
+
public:
|
| 40 |
+
// Constructor initializes the flags that determine which implementation to
|
| 41 |
+
// use at run-time, constrained by both compiler flags and cpuid.
|
| 42 |
+
MatmulBase() {
|
| 43 |
+
#if defined(__x86_64__) || defined(__i386__) || defined(_WIN32)
|
| 44 |
+
// Code tested to work on Linux systems and multiple Android emulators.
|
| 45 |
+
unsigned int eax, ebx, ecx, edx;
|
| 46 |
+
if (__get_cpuid(1, &eax, &ebx, &ecx, &edx) != 0) {
|
| 47 |
+
using_avx_ = (ecx & bit_AVX) != 0;
|
| 48 |
+
if (using_avx_) {
|
| 49 |
+
__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx);
|
| 50 |
+
using_avx2_ = (ebx & bit_AVX2) != 0;
|
| 51 |
+
using_avx512_ = (ebx & bit_AVX512F) != 0 && (ebx & bit_AVX512DQ) &&
|
| 52 |
+
(ebx & bit_AVX512BW) != 0;
|
| 53 |
+
VLOG(2) << "avx2 flag=" << using_avx2_ << " 512=" << using_avx512_;
|
| 54 |
+
} else {
|
| 55 |
+
LOG(ERROR) << "AVX not found at all!";
|
| 56 |
+
}
|
| 57 |
+
}
|
| 58 |
+
#else
|
| 59 |
+
using_aarch64_ = true;
|
| 60 |
+
#endif
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
protected:
|
| 64 |
+
// Flags that define what (runtime) architectures are available. Flags that
|
| 65 |
+
// are set are limited by both the compiler flags and runtime environment.
|
| 66 |
+
bool using_avx512_ = false;
|
| 67 |
+
bool using_avx2_ = false;
|
| 68 |
+
bool using_avx_ = false;
|
| 69 |
+
bool using_aarch64_ = false;
|
| 70 |
+
};
|
| 71 |
+
|
| 72 |
+
// The master template is really a catch-all for the unimplmented cases to
|
| 73 |
+
// report an error.
|
| 74 |
+
template <typename WeightType, typename RhsType>
|
| 75 |
+
class Matmul : public MatmulBase {
|
| 76 |
+
public:
|
| 77 |
+
// Sparse inputs, outputs replicated strided for each thread.
|
| 78 |
+
template <typename OutType>
|
| 79 |
+
void MatVec4x4(const WeightType* weights, const RhsType* rhs,
|
| 80 |
+
const typename TypeOfProduct<WeightType, RhsType>::type* bias,
|
| 81 |
+
const int32_t* nnz_per_row, const int16_t* rhs_indices,
|
| 82 |
+
int start_row, int end_row, bool relu, int replicas,
|
| 83 |
+
int stride, OutType* output) {
|
| 84 |
+
// The specializations should take care of every real case.
|
| 85 |
+
CHECK(false) << "Unsupported combination of types used!";
|
| 86 |
+
}
|
| 87 |
+
template <typename OutType>
|
| 88 |
+
void MatVec8x4(const WeightType* weights, const RhsType* rhs,
|
| 89 |
+
const typename TypeOfProduct<WeightType, RhsType>::type* bias,
|
| 90 |
+
const int32_t* nnz_per_row, const int16_t* rhs_indices,
|
| 91 |
+
int start_row, int end_row, bool relu, int replicas,
|
| 92 |
+
int stride, OutType* output) {
|
| 93 |
+
// The specializations should take care of every real case.
|
| 94 |
+
CHECK(false) << "Unsupported combination of types used!";
|
| 95 |
+
}
|
| 96 |
+
};
|
| 97 |
+
|
| 98 |
+
// Full specialization for float.
|
| 99 |
+
template <>
|
| 100 |
+
class Matmul<float, float> : public MatmulBase {
|
| 101 |
+
public:
|
| 102 |
+
void MatVec4x4(const float* weights, const float* rhs, const float* bias,
|
| 103 |
+
const int32_t* nnz_per_row, const int16_t* rhs_indices,
|
| 104 |
+
int start_row, int end_row, bool relu, int replicas,
|
| 105 |
+
int stride, float* output) {
|
| 106 |
+
detail::MatVecFloatGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
|
| 107 |
+
start_row, end_row, /*block_height=*/4,
|
| 108 |
+
/*block_width=*/4, relu, replicas, stride,
|
| 109 |
+
output);
|
| 110 |
+
}
|
| 111 |
+
void MatVec8x4(const float* weights, const float* rhs, const float* bias,
|
| 112 |
+
const int32_t* nnz_per_row, const int16_t* rhs_indices,
|
| 113 |
+
int start_row, int end_row, bool relu, int replicas,
|
| 114 |
+
int stride, float* output) {
|
| 115 |
+
detail::MatVecFloatGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
|
| 116 |
+
start_row, end_row, /*block_height=*/8,
|
| 117 |
+
/*block_width=*/4, relu, replicas, stride,
|
| 118 |
+
output);
|
| 119 |
+
}
|
| 120 |
+
};
|
| 121 |
+
|
| 122 |
+
// Partial specialization for fixed types. Covers fixed16xfixed16 = OutType,
|
| 123 |
+
// where OutType should be fixed16 or fixed32. The mantissa bits don't have
|
| 124 |
+
// to match.
|
| 125 |
+
template <int WeightBits, int RhsBits>
|
| 126 |
+
class Matmul<fixed16<WeightBits>, fixed16<RhsBits>> : public MatmulBase {
|
| 127 |
+
public:
|
| 128 |
+
using WeightType = fixed16<WeightBits>;
|
| 129 |
+
using RhsType = fixed16<RhsBits>;
|
| 130 |
+
|
| 131 |
+
template <typename OutType>
|
| 132 |
+
void MatVec4x4(const int16_t* weights, const int16_t* rhs,
|
| 133 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
| 134 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
| 135 |
+
bool relu, int replicas, int stride, OutType* output) {
|
| 136 |
+
constexpr int kShiftAmount =
|
| 137 |
+
TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
|
| 138 |
+
OutType::kMantissaBits;
|
| 139 |
+
static_assert(kShiftAmount >= 0,
|
| 140 |
+
"OutType must not have more mantissa bits than inputs");
|
| 141 |
+
#if defined __AVX2__
|
| 142 |
+
CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!";
|
| 143 |
+
if (sizeof(*output) == 4) {
|
| 144 |
+
int32_t* out32 = reinterpret_cast<int32_t*>(output);
|
| 145 |
+
detail::MatVec4x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
|
| 146 |
+
start_row, end_row, relu, kShiftAmount,
|
| 147 |
+
replicas, stride, out32);
|
| 148 |
+
} else {
|
| 149 |
+
int16_t* out16 = reinterpret_cast<int16_t*>(output);
|
| 150 |
+
detail::MatVec4x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
|
| 151 |
+
start_row, end_row, relu, kShiftAmount,
|
| 152 |
+
replicas, stride, out16);
|
| 153 |
+
}
|
| 154 |
+
#elif defined __aarch64__
|
| 155 |
+
if (using_aarch64_) {
|
| 156 |
+
LOG(FATAL) << "Fixed16 MatVec4x4 not yet implemented!";
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
#else
|
| 160 |
+
detail::MatVecFixedGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
|
| 161 |
+
start_row, end_row, /*block_height=*/4,
|
| 162 |
+
/*block_width=*/4, relu, sizeof(*output),
|
| 163 |
+
kShiftAmount, replicas, stride, output);
|
| 164 |
+
#endif // __AVX2__
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
template <typename OutType>
|
| 168 |
+
void MatVec8x4(const int16_t* weights, const int16_t* rhs,
|
| 169 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
| 170 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
| 171 |
+
bool relu, int replicas, int stride, OutType* output) {
|
| 172 |
+
constexpr int kShiftAmount =
|
| 173 |
+
TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
|
| 174 |
+
OutType::kMantissaBits;
|
| 175 |
+
static_assert(kShiftAmount >= 0,
|
| 176 |
+
"OutType must not have more mantissa bits than inputs");
|
| 177 |
+
#if defined __AVX2__
|
| 178 |
+
CHECK(replicas == 1 && sizeof(*output) == 4)
|
| 179 |
+
<< "Only replicas == 1 and fixed32 output are implemented for AVX2!";
|
| 180 |
+
CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!";
|
| 181 |
+
int32_t* out32 = reinterpret_cast<int32_t*>(output);
|
| 182 |
+
detail::MatVec8x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
|
| 183 |
+
start_row, end_row, relu, kShiftAmount, out32);
|
| 184 |
+
#elif defined __aarch64__
|
| 185 |
+
if (using_aarch64_) {
|
| 186 |
+
LOG(FATAL) << "Fixed16 MatVec8x4 not yet implemented!";
|
| 187 |
+
}
|
| 188 |
+
#else
|
| 189 |
+
detail::MatVecFixedGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
|
| 190 |
+
start_row, end_row, /*block_height=*/8,
|
| 191 |
+
/*block_width=*/4, relu, sizeof(*output),
|
| 192 |
+
kShiftAmount, replicas, stride, output);
|
| 193 |
+
#endif // __AVX2__
|
| 194 |
+
}
|
| 195 |
+
};
|
| 196 |
+
|
| 197 |
+
} // namespace csrblocksparse
|
| 198 |
+
|
| 199 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_
|
sparse_matmul/compute/matmul_fixed_avx2.cc
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright 2021 Google LLC
|
| 2 |
+
//
|
| 3 |
+
// Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
// you may not use this file except in compliance with the License.
|
| 5 |
+
// You may obtain a copy of the License at
|
| 6 |
+
//
|
| 7 |
+
// http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
//
|
| 9 |
+
// Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
// distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
// See the License for the specific language governing permissions and
|
| 13 |
+
// limitations under the License.
|
| 14 |
+
|
| 15 |
+
#include "sparse_matmul/compute/matmul_fixed_avx2.h"
|
| 16 |
+
|
| 17 |
+
#include <cstdint>
|
| 18 |
+
|
| 19 |
+
#if defined __AVX__
|
| 20 |
+
#include <immintrin.h>
|
| 21 |
+
#endif
|
| 22 |
+
|
| 23 |
+
#include "sparse_matmul/compute/matmul.h"
|
| 24 |
+
|
| 25 |
+
namespace csrblocksparse {
|
| 26 |
+
namespace detail {
|
| 27 |
+
|
| 28 |
+
static const int32_t kint32min = static_cast<int32_t>(~0x7FFFFFFF);
|
| 29 |
+
static const int32_t kint32max = static_cast<int32_t>(0x7FFFFFFF);
|
| 30 |
+
|
| 31 |
+
#if defined __AVX2__
|
| 32 |
+
// In-line function computes and returns the result of one row (of blocks) as
|
| 33 |
+
// 4x int32_t. |weights_ptr| is a non-const reference so it can easily be
|
| 34 |
+
// interpreted as belonging to the caller.
|
| 35 |
+
inline __m256i ComputeRowResults(const __m128i& bias128, const int16_t* rhs,
|
| 36 |
+
const int16_t* rhs_indices, int nnz,
|
| 37 |
+
int16_t const*& weights_ptr) {
|
| 38 |
+
// Expand bias to 64 bits in a 256 bit register [0 z 1 z 2 z 3 z], where z is
|
| 39 |
+
// Zero and 0-3 are the 4x32 bit bias values.
|
| 40 |
+
__m256i sum = _mm256_cvtepu32_epi64(bias128);
|
| 41 |
+
|
| 42 |
+
for (int c = 0; c < nnz; ++c) {
|
| 43 |
+
int rhs_index = rhs_indices[c];
|
| 44 |
+
// Load all 16 weights.
|
| 45 |
+
__m256i weights =
|
| 46 |
+
_mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
|
| 47 |
+
// Get the 4x int16_t into the bottom of |rhs_64|.
|
| 48 |
+
__m128i rhs_64 = _mm_loadl_epi64(
|
| 49 |
+
reinterpret_cast<__m128i const*>(rhs + rhs_index * kBlockSize));
|
| 50 |
+
// Broadcast the rhs, pretending that each is a 64-bit unit:
|
| 51 |
+
// [0123 0123 0123 0123].
|
| 52 |
+
__m256i rhs_value = _mm256_broadcastq_epi64(rhs_64);
|
| 53 |
+
weights_ptr += 16;
|
| 54 |
+
sum = _mm256_add_epi32(sum, _mm256_madd_epi16(weights, rhs_value));
|
| 55 |
+
}
|
| 56 |
+
// Horizontally add the results. We have 1 register that contains results
|
| 57 |
+
// [0 0 1 1 2 2 3 3], but hadd (and almost no other AVX instruction) will not
|
| 58 |
+
// cross lanes, so we end up with [0 1 0 1 2 3 2 3]
|
| 59 |
+
sum = _mm256_hadd_epi32(sum, sum);
|
| 60 |
+
// Permutes the middle two pairs to get the answers together.
|
| 61 |
+
return _mm256_permute4x64_epi64(sum, 0xd8);
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
// Template that allows any fixed combination of OutType and replicas, plus
|
| 65 |
+
// variable |relu|, |shift_out|. Note that |kReplicas| is a template arg as
|
| 66 |
+
// well as a function arg so we can hard-code a limited amount of unrolling.
|
| 67 |
+
template <typename OutType, int kReplicas>
|
| 68 |
+
void MatVec4x4FixedAVX2Template(const int16_t* weights_ptr, const int16_t* rhs,
|
| 69 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
| 70 |
+
const int16_t* rhs_indices, int start_row,
|
| 71 |
+
int end_row, bool relu, int shift_out,
|
| 72 |
+
int replicas, int stride, OutType* output) {
|
| 73 |
+
int rounding_addon = shift_out > 0 ? (1 << (shift_out - 1)) : 0;
|
| 74 |
+
__m256i rounding = _mm256_set1_epi32(rounding_addon);
|
| 75 |
+
__m256i zero = relu ? _mm256_setzero_si256() : _mm256_set1_epi32(kint32min);
|
| 76 |
+
for (int row_block = start_row; row_block < end_row; ++row_block) {
|
| 77 |
+
// Load 4 biases [0 1 2 3].
|
| 78 |
+
__m128i bias128 = _mm_load_si128(reinterpret_cast<__m128i const*>(bias));
|
| 79 |
+
bias += kBlockSize;
|
| 80 |
+
int nnz = nnz_per_row[row_block];
|
| 81 |
+
__m256i sum =
|
| 82 |
+
ComputeRowResults(bias128, rhs, rhs_indices, nnz, weights_ptr);
|
| 83 |
+
rhs_indices += nnz;
|
| 84 |
+
// Shift right with rounding to get the right number of mantissa bits.
|
| 85 |
+
sum = _mm256_add_epi32(sum, rounding);
|
| 86 |
+
sum = _mm256_srai_epi32(sum, shift_out);
|
| 87 |
+
// Now sum contains [res0, res1, res2, res3, res0, res1, res2, res3]
|
| 88 |
+
sum = _mm256_max_epi32(sum, zero);
|
| 89 |
+
if (sizeof(OutType) == 2) {
|
| 90 |
+
// Clip to 16 bit range (with saturation) and pack in the bottom 64
|
| 91 |
+
// bits. The 64 bit result is replicated across the whole 256 bit
|
| 92 |
+
// register. [0123 0123 0123 0123]
|
| 93 |
+
sum = _mm256_packs_epi32(sum, sum);
|
| 94 |
+
int64_t result = _mm256_extract_epi64(sum, 0);
|
| 95 |
+
*reinterpret_cast<int64_t*>(output) = result;
|
| 96 |
+
if (kReplicas > 1) {
|
| 97 |
+
*reinterpret_cast<int64_t*>(output + stride) = result;
|
| 98 |
+
if (kReplicas > 2) {
|
| 99 |
+
for (int r = 2; r < replicas; ++r) {
|
| 100 |
+
*reinterpret_cast<int64_t*>(output + r * stride) = result;
|
| 101 |
+
}
|
| 102 |
+
}
|
| 103 |
+
}
|
| 104 |
+
} else {
|
| 105 |
+
// Save the lower 128 bits (4x int32_t).
|
| 106 |
+
__m128i result = _mm256_extractf128_si256(sum, 0);
|
| 107 |
+
_mm_store_si128(reinterpret_cast<__m128i*>(output), result);
|
| 108 |
+
if (kReplicas > 1) {
|
| 109 |
+
_mm_store_si128(reinterpret_cast<__m128i*>(output + stride), result);
|
| 110 |
+
if (kReplicas > 2) {
|
| 111 |
+
for (int r = 2; r < replicas; ++r) {
|
| 112 |
+
_mm_store_si128(reinterpret_cast<__m128i*>(output + r * stride),
|
| 113 |
+
result);
|
| 114 |
+
}
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
}
|
| 118 |
+
output += kBlockSize;
|
| 119 |
+
}
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
// Version that covers all possible combinations of the variable conditions:
|
| 123 |
+
// |relu|, |shift_out|, |replicas|, with int16_t |output|.
|
| 124 |
+
void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
|
| 125 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
| 126 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
| 127 |
+
bool relu, int shift_out, int replicas, int stride,
|
| 128 |
+
int16_t* output) {
|
| 129 |
+
if (replicas <= 1) {
|
| 130 |
+
MatVec4x4FixedAVX2Template<int16_t, 1>(weights_ptr, rhs, bias, nnz_per_row,
|
| 131 |
+
rhs_indices, start_row, end_row,
|
| 132 |
+
relu, shift_out, 1, stride, output);
|
| 133 |
+
} else if (replicas == 2) {
|
| 134 |
+
MatVec4x4FixedAVX2Template<int16_t, 2>(weights_ptr, rhs, bias, nnz_per_row,
|
| 135 |
+
rhs_indices, start_row, end_row,
|
| 136 |
+
relu, shift_out, 2, stride, output);
|
| 137 |
+
} else {
|
| 138 |
+
MatVec4x4FixedAVX2Template<int16_t, 3>(
|
| 139 |
+
weights_ptr, rhs, bias, nnz_per_row, rhs_indices, start_row, end_row,
|
| 140 |
+
relu, shift_out, replicas, stride, output);
|
| 141 |
+
}
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
// Version that covers all possible combinations of the variable conditions:
|
| 145 |
+
// |relu|, |shift_out|, |replicas|, with int32_t |output|.
|
| 146 |
+
void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
|
| 147 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
| 148 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
| 149 |
+
bool relu, int shift_out, int replicas, int stride,
|
| 150 |
+
int32_t* output) {
|
| 151 |
+
if (replicas <= 1) {
|
| 152 |
+
MatVec4x4FixedAVX2Template<int32_t, 1>(weights_ptr, rhs, bias, nnz_per_row,
|
| 153 |
+
rhs_indices, start_row, end_row,
|
| 154 |
+
relu, shift_out, 1, stride, output);
|
| 155 |
+
} else if (replicas == 2) {
|
| 156 |
+
MatVec4x4FixedAVX2Template<int32_t, 2>(weights_ptr, rhs, bias, nnz_per_row,
|
| 157 |
+
rhs_indices, start_row, end_row,
|
| 158 |
+
relu, shift_out, 2, stride, output);
|
| 159 |
+
} else {
|
| 160 |
+
MatVec4x4FixedAVX2Template<int32_t, 3>(
|
| 161 |
+
weights_ptr, rhs, bias, nnz_per_row, rhs_indices, start_row, end_row,
|
| 162 |
+
relu, shift_out, replicas, stride, output);
|
| 163 |
+
}
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
// In-line function computes and returns the result of one row (of blocks) as
|
| 167 |
+
// 8x int32_t. weights_ptr is a non-const reference so it can easily be
|
| 168 |
+
// interpreted as belonging to the caller.
|
| 169 |
+
inline __m256i Compute8RowResults(const __m256i& bias256, const int16_t* rhs,
|
| 170 |
+
const int16_t* rhs_indices, int nnz,
|
| 171 |
+
int16_t const*& weights_ptr) {
|
| 172 |
+
// Expand bias to 64 bits in a 256 bit register [0 z 1 z 2 z 3 z], where z is
|
| 173 |
+
// Zero and 0-3 are the 4x32 bit bias values from 128 bit half of the input.
|
| 174 |
+
__m256i sum1 = _mm256_cvtepu32_epi64(_mm256_castsi256_si128(bias256));
|
| 175 |
+
// Plus 4 more in another sum register from the upper 128 bit half.
|
| 176 |
+
__m256i sum2 = _mm256_cvtepu32_epi64(_mm256_extractf128_si256(bias256, 1));
|
| 177 |
+
|
| 178 |
+
for (int c = 0; c < nnz; ++c) {
|
| 179 |
+
int rhs_index = rhs_indices[c];
|
| 180 |
+
// Load all 16 weights.
|
| 181 |
+
__m256i weights =
|
| 182 |
+
_mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
|
| 183 |
+
// Get the 4x int16_t into the bottom of |rhs_64|.
|
| 184 |
+
__m128i rhs_64 = _mm_loadl_epi64(
|
| 185 |
+
reinterpret_cast<__m128i const*>(rhs + rhs_index * kBlockSize));
|
| 186 |
+
// Broadcast the rhs, pretending that each is a 64-bit unit:
|
| 187 |
+
// [0123 0123 0123 0123].
|
| 188 |
+
__m256i rhs_value = _mm256_broadcastq_epi64(rhs_64);
|
| 189 |
+
weights_ptr += 16;
|
| 190 |
+
sum1 = _mm256_add_epi32(sum1, _mm256_madd_epi16(weights, rhs_value));
|
| 191 |
+
// Same again for the other 4 results, re-using the same rhs value.
|
| 192 |
+
weights = _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
|
| 193 |
+
weights_ptr += 16;
|
| 194 |
+
sum2 = _mm256_add_epi32(sum2, _mm256_madd_epi16(weights, rhs_value));
|
| 195 |
+
}
|
| 196 |
+
// Horizontally add the results. We have 2 registers that contain results
|
| 197 |
+
// [0 0 1 1 2 2 3 3], and [4 4 5 5 6 6 7 7] but hadd (and almost no other AVX
|
| 198 |
+
// instruction) will not cross lanes, so we end up with [0 1 4 5 2 3 6 7]
|
| 199 |
+
sum1 = _mm256_hadd_epi32(sum1, sum2);
|
| 200 |
+
// Permutes the middle two pairs to get the answers in the right order.
|
| 201 |
+
return _mm256_permute4x64_epi64(sum1, 0xd8);
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
// Version that covers the main conditions used with 8x4:
|
| 205 |
+
// |relu|, |shift_out|, with int32_t |output|.
|
| 206 |
+
void MatVec8x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
|
| 207 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
| 208 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
| 209 |
+
bool relu, int shift_out, int32_t* output) {
|
| 210 |
+
int rounding_addon = shift_out > 0 ? (1 << (shift_out - 1)) : 0;
|
| 211 |
+
__m256i rounding = _mm256_set1_epi32(rounding_addon);
|
| 212 |
+
__m256i zero = relu ? _mm256_setzero_si256() : _mm256_set1_epi32(kint32min);
|
| 213 |
+
for (int row_block = start_row; row_block < end_row; ++row_block) {
|
| 214 |
+
// Load 4 biases [0 1 2 3 4 5 6 7].
|
| 215 |
+
__m256i bias256 = _mm256_load_si256(reinterpret_cast<__m256i const*>(bias));
|
| 216 |
+
bias += kBlockSize * 2;
|
| 217 |
+
int nnz = nnz_per_row[row_block];
|
| 218 |
+
__m256i sum =
|
| 219 |
+
Compute8RowResults(bias256, rhs, rhs_indices, nnz, weights_ptr);
|
| 220 |
+
rhs_indices += nnz;
|
| 221 |
+
// Shift right with rounding to get the right number of mantissa bits.
|
| 222 |
+
sum = _mm256_add_epi32(sum, rounding);
|
| 223 |
+
sum = _mm256_srai_epi32(sum, shift_out);
|
| 224 |
+
// Now sum contains [res0, res1, res2, res3, res0, res1, res2, res3]
|
| 225 |
+
sum = _mm256_max_epi32(sum, zero);
|
| 226 |
+
// Save the all 256 bits (8x int32_t).
|
| 227 |
+
_mm256_store_si256(reinterpret_cast<__m256i*>(output), sum);
|
| 228 |
+
output += kBlockSize * 2;
|
| 229 |
+
}
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
#endif
|
| 233 |
+
|
| 234 |
+
} // namespace detail
|
| 235 |
+
} // namespace csrblocksparse
|
sparse_matmul/compute/matmul_fixed_avx2.h
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2021 Google LLC
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_FIXED_AVX2_H_
|
| 18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_FIXED_AVX2_H_
|
| 19 |
+
|
| 20 |
+
#include <cstdint>
|
| 21 |
+
|
| 22 |
+
namespace csrblocksparse {
|
| 23 |
+
namespace detail {
|
| 24 |
+
|
| 25 |
+
// Version that covers all possible combinations of the variable conditions:
|
| 26 |
+
// |relu|, |shift_out|, |replicas|, with int16 output.
|
| 27 |
+
void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
|
| 28 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
| 29 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
| 30 |
+
bool relu, int shift_out, int replicas, int stride,
|
| 31 |
+
int16_t* output);
|
| 32 |
+
// Version that covers all possible combinations of the variable conditions:
|
| 33 |
+
// |relu|, |shift_out|, |replicas|, with int32 output.
|
| 34 |
+
void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
|
| 35 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
| 36 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
| 37 |
+
bool relu, int shift_out, int replicas, int stride,
|
| 38 |
+
int32_t* output);
|
| 39 |
+
// Version that covers the main conditions used with 8x4:
|
| 40 |
+
// |relu|, |shift_out|, with int32 output.
|
| 41 |
+
void MatVec8x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
|
| 42 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
| 43 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
| 44 |
+
bool relu, int shift_out, int32_t* output);
|
| 45 |
+
|
| 46 |
+
} // namespace detail
|
| 47 |
+
} // namespace csrblocksparse
|
| 48 |
+
|
| 49 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_FIXED_AVX2_H_
|
sparse_matmul/compute/matmul_generic.cc
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright 2021 Google LLC
|
| 2 |
+
//
|
| 3 |
+
// Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
// you may not use this file except in compliance with the License.
|
| 5 |
+
// You may obtain a copy of the License at
|
| 6 |
+
//
|
| 7 |
+
// http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
//
|
| 9 |
+
// Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
// distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
// See the License for the specific language governing permissions and
|
| 13 |
+
// limitations under the License.
|
| 14 |
+
|
| 15 |
+
#include "sparse_matmul/compute/matmul_generic.h"
|
| 16 |
+
|
| 17 |
+
#include <cstdint>
|
| 18 |
+
#include <vector>
|
| 19 |
+
|
| 20 |
+
#include "sparse_matmul/compute/matmul.h"
|
| 21 |
+
|
| 22 |
+
namespace csrblocksparse {
|
| 23 |
+
namespace detail {
|
| 24 |
+
|
| 25 |
+
void MatVecFloatGeneric(const float* weights, const float* rhs,
|
| 26 |
+
const float* bias, const int32_t* nnz_per_row,
|
| 27 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
| 28 |
+
int block_height, int block_width, bool relu,
|
| 29 |
+
int replicas, int stride, float* output) {
|
| 30 |
+
int weight_index = 0;
|
| 31 |
+
int bias_index = 0;
|
| 32 |
+
std::vector<float> accumulators(block_height);
|
| 33 |
+
for (int row_block = start_row; row_block < end_row;
|
| 34 |
+
++row_block, output += block_height) {
|
| 35 |
+
int nnz = nnz_per_row[row_block];
|
| 36 |
+
// Biases are now stored and used directly without pre-division.
|
| 37 |
+
for (int i = 0; i < block_height; ++i) accumulators[i] = bias[bias_index++];
|
| 38 |
+
|
| 39 |
+
for (int c = 0; c < nnz; ++c) {
|
| 40 |
+
int rhs_index = rhs_indices[c];
|
| 41 |
+
const float* block_rhs = rhs + rhs_index * block_width;
|
| 42 |
+
// Multiply this |block_height| x |block_width| block.
|
| 43 |
+
for (int i = 0; i < block_height; ++i) {
|
| 44 |
+
for (int j = 0; j < block_width; ++j) {
|
| 45 |
+
accumulators[i] += weights[weight_index++] * block_rhs[j];
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
}
|
| 49 |
+
rhs_indices += nnz;
|
| 50 |
+
// Apply relu if desired.
|
| 51 |
+
if (relu) {
|
| 52 |
+
for (int i = 0; i < block_height; ++i) {
|
| 53 |
+
if (accumulators[i] < 0) accumulators[i] = 0;
|
| 54 |
+
}
|
| 55 |
+
}
|
| 56 |
+
for (int r = 0; r < replicas; ++r) {
|
| 57 |
+
for (int i = 0; i < block_height; ++i) {
|
| 58 |
+
output[i + r * stride] = accumulators[i];
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
}
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
void MatVecFixedGeneric(const int16_t* weights, const int16_t* rhs,
|
| 65 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
| 66 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
| 67 |
+
int block_height, int block_width, bool relu,
|
| 68 |
+
int bytes_out, int shift_out, int replicas, int stride,
|
| 69 |
+
void* output) {
|
| 70 |
+
int weight_index = 0;
|
| 71 |
+
int bias_index = 0;
|
| 72 |
+
std::vector<int32_t> accumulators(block_height);
|
| 73 |
+
for (int row_block = start_row; row_block < end_row; ++row_block) {
|
| 74 |
+
int nnz = nnz_per_row[row_block];
|
| 75 |
+
// Biases are now stored and used directly without pre-division.
|
| 76 |
+
for (int i = 0; i < block_height; ++i) accumulators[i] = bias[bias_index++];
|
| 77 |
+
|
| 78 |
+
for (int c = 0; c < nnz; ++c) {
|
| 79 |
+
int rhs_index = rhs_indices[c];
|
| 80 |
+
const int16_t* block_rhs = rhs + rhs_index * block_width;
|
| 81 |
+
// Multiply this |block_height| x |block_width| block.
|
| 82 |
+
for (int i = 0; i < block_height; ++i) {
|
| 83 |
+
for (int j = 0; j < block_width; ++j) {
|
| 84 |
+
accumulators[i] += weights[weight_index++] * block_rhs[j];
|
| 85 |
+
}
|
| 86 |
+
}
|
| 87 |
+
}
|
| 88 |
+
rhs_indices += nnz;
|
| 89 |
+
// Apply relu if desired.
|
| 90 |
+
if (relu) {
|
| 91 |
+
for (int i = 0; i < block_height; ++i) {
|
| 92 |
+
if (accumulators[i] < 0) accumulators[i] = 0;
|
| 93 |
+
}
|
| 94 |
+
}
|
| 95 |
+
// Output shift.
|
| 96 |
+
if (shift_out > 0) {
|
| 97 |
+
for (int i = 0; i < block_height; ++i) {
|
| 98 |
+
accumulators[i] >>= shift_out;
|
| 99 |
+
}
|
| 100 |
+
}
|
| 101 |
+
if (bytes_out == 2) {
|
| 102 |
+
int16_t* out16 = reinterpret_cast<int16_t*>(output);
|
| 103 |
+
output = out16 + block_height;
|
| 104 |
+
for (int r = 0; r < replicas; ++r, out16 += stride) {
|
| 105 |
+
for (int i = 0; i < block_height; ++i) {
|
| 106 |
+
out16[i] = accumulators[i];
|
| 107 |
+
}
|
| 108 |
+
}
|
| 109 |
+
} else {
|
| 110 |
+
int32_t* out32 = reinterpret_cast<int32_t*>(output);
|
| 111 |
+
output = out32 + block_height;
|
| 112 |
+
for (int r = 0; r < replicas; ++r, out32 += stride) {
|
| 113 |
+
for (int i = 0; i < block_height; ++i) {
|
| 114 |
+
out32[i] = accumulators[i];
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
}
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
} // namespace detail
|
| 122 |
+
} // namespace csrblocksparse
|
sparse_matmul/compute/matmul_generic.h
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2021 Google LLC
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_GENERIC_H_
|
| 18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_GENERIC_H_
|
| 19 |
+
|
| 20 |
+
#include <cstdint>
|
| 21 |
+
|
| 22 |
+
namespace csrblocksparse {
|
| 23 |
+
namespace detail {
|
| 24 |
+
|
| 25 |
+
// Generic version uses plain C++ code.
|
| 26 |
+
void MatVecFloatGeneric(const float* weights, const float* rhs,
|
| 27 |
+
const float* bias, const int32_t* nnz_per_row,
|
| 28 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
| 29 |
+
int block_height, int block_width, bool relu,
|
| 30 |
+
int replicas, int stride, float* output);
|
| 31 |
+
void MatVecFixedGeneric(const int16_t* weights, const int16_t* rhs,
|
| 32 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
| 33 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
| 34 |
+
int block_height, int block_width, bool relu,
|
| 35 |
+
int bytes_out, int shift_out, int replicas, int stride,
|
| 36 |
+
void* output);
|
| 37 |
+
|
| 38 |
+
} // namespace detail
|
| 39 |
+
} // namespace csrblocksparse
|
| 40 |
+
|
| 41 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_GENERIC_H_
|
sparse_matmul/compute/thread_bounds.cc
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright 2021 Google LLC
|
| 2 |
+
//
|
| 3 |
+
// Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
// you may not use this file except in compliance with the License.
|
| 5 |
+
// You may obtain a copy of the License at
|
| 6 |
+
//
|
| 7 |
+
// http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
//
|
| 9 |
+
// Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
// distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
// See the License for the specific language governing permissions and
|
| 13 |
+
// limitations under the License.
|
| 14 |
+
|
| 15 |
+
#include "sparse_matmul/compute/thread_bounds.h"
|
| 16 |
+
|
| 17 |
+
#include <vector>
|
| 18 |
+
|
| 19 |
+
#include "glog/logging.h"
|
| 20 |
+
|
| 21 |
+
namespace csrblocksparse {
|
| 22 |
+
|
| 23 |
+
void ThreadBounds::PrepareForThreads(int block_width, int block_height,
|
| 24 |
+
int num_threads,
|
| 25 |
+
int reduced_rows_per_cache_row,
|
| 26 |
+
int reduced_rows, const int* nnz_per_row) {
|
| 27 |
+
CHECK_GT(num_threads, 0);
|
| 28 |
+
block_width_ = block_width;
|
| 29 |
+
block_height_ = block_height;
|
| 30 |
+
ComputeThreadSplitPoints(num_threads, reduced_rows_per_cache_row,
|
| 31 |
+
reduced_rows, nnz_per_row);
|
| 32 |
+
weight_starts_.clear();
|
| 33 |
+
rhs_indices_starts_.clear();
|
| 34 |
+
bias_starts_.clear();
|
| 35 |
+
weight_starts_.reserve(row_starts_.size());
|
| 36 |
+
rhs_indices_starts_.reserve(row_starts_.size());
|
| 37 |
+
bias_starts_.reserve(row_starts_.size());
|
| 38 |
+
|
| 39 |
+
// Compute the start indices of each of the types, given what we know about
|
| 40 |
+
// padding, and number of |nnz_per_row|.
|
| 41 |
+
int weight_index = 0;
|
| 42 |
+
int rhs_indices_index = 0;
|
| 43 |
+
int bias_index = 0;
|
| 44 |
+
int row = 0;
|
| 45 |
+
for (int start : row_starts_) {
|
| 46 |
+
while (row < start) {
|
| 47 |
+
weight_index += nnz_per_row[row] * block_width_ * block_height_;
|
| 48 |
+
rhs_indices_index += nnz_per_row[row];
|
| 49 |
+
bias_index += block_height_;
|
| 50 |
+
++row;
|
| 51 |
+
}
|
| 52 |
+
weight_starts_.push_back(weight_index);
|
| 53 |
+
rhs_indices_starts_.push_back(rhs_indices_index);
|
| 54 |
+
bias_starts_.push_back(bias_index);
|
| 55 |
+
}
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
// Computes the block row (reduced) index of the start of each thread.
|
| 59 |
+
void ThreadBounds::ComputeThreadSplitPoints(int num_threads,
|
| 60 |
+
int reduced_rows_per_cache_row,
|
| 61 |
+
int reduced_rows,
|
| 62 |
+
const int* nnz_per_row) {
|
| 63 |
+
row_starts_.assign(/*n=*/1, /*val=*/0);
|
| 64 |
+
// Break the rule if the matrix is too small to allow one per thread, which
|
| 65 |
+
// occurs only during tests.
|
| 66 |
+
if (reduced_rows_per_cache_row * num_threads > reduced_rows)
|
| 67 |
+
reduced_rows_per_cache_row = std::max(reduced_rows / num_threads, 1);
|
| 68 |
+
int cache_rows = (reduced_rows + reduced_rows_per_cache_row - 1) /
|
| 69 |
+
reduced_rows_per_cache_row;
|
| 70 |
+
|
| 71 |
+
// Compute exclusive prefix sum of the amount of work per row.
|
| 72 |
+
std::vector<int> work_upto_row(cache_rows + 1, 0);
|
| 73 |
+
int extra_row_work = 2 * reduced_rows_per_cache_row;
|
| 74 |
+
for (int i = 0; i < cache_rows; ++i) {
|
| 75 |
+
int new_nnz = 0;
|
| 76 |
+
for (int j = 0; j < reduced_rows_per_cache_row; ++j) {
|
| 77 |
+
// if |reduced_rows_per_cache_row| isn't an exact multiple of the
|
| 78 |
+
// matrix size, then we need to be careful here.
|
| 79 |
+
int index = i * reduced_rows_per_cache_row + j;
|
| 80 |
+
if (index < reduced_rows) new_nnz += nnz_per_row[index];
|
| 81 |
+
}
|
| 82 |
+
work_upto_row[i + 1] = new_nnz + extra_row_work + work_upto_row[i];
|
| 83 |
+
}
|
| 84 |
+
int total_work = work_upto_row.back();
|
| 85 |
+
// Find the split point point based on assigned approximately equal amount
|
| 86 |
+
// of work for each thread.
|
| 87 |
+
int prev_split = 0;
|
| 88 |
+
for (int i = 1; i <= num_threads; ++i) {
|
| 89 |
+
int split = std::distance(
|
| 90 |
+
work_upto_row.begin(),
|
| 91 |
+
std::lower_bound(work_upto_row.begin(), work_upto_row.end(),
|
| 92 |
+
i * total_work / num_threads));
|
| 93 |
+
int split_row = split * reduced_rows_per_cache_row;
|
| 94 |
+
if (i == num_threads) {
|
| 95 |
+
split_row = reduced_rows;
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
VLOG(2) << "tid=" << i - 1 << " num rows=" << split_row - row_starts_.back()
|
| 99 |
+
<< " work=" << work_upto_row[split] - work_upto_row[prev_split];
|
| 100 |
+
row_starts_.push_back(split_row);
|
| 101 |
+
prev_split = split;
|
| 102 |
+
}
|
| 103 |
+
VLOG(2) << "total rows=" << reduced_rows << " total work=" << total_work;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
} // namespace csrblocksparse
|
sparse_matmul/compute/thread_bounds.h
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2021 Google LLC
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_THREAD_BOUNDS_H_
|
| 18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_THREAD_BOUNDS_H_
|
| 19 |
+
|
| 20 |
+
#include <vector>
|
| 21 |
+
|
| 22 |
+
namespace csrblocksparse {
|
| 23 |
+
|
| 24 |
+
// Class to compute and store the bounds of each thread used in a computation,
|
| 25 |
+
// and to provide corresponding spans of vectors.
|
| 26 |
+
class ThreadBounds {
|
| 27 |
+
public:
|
| 28 |
+
ThreadBounds() : block_width_(0), block_height_(0) {}
|
| 29 |
+
|
| 30 |
+
void PrepareForThreads(int block_width, int block_height, int num_threads,
|
| 31 |
+
int reduced_rows_per_cache_row, int reduced_rows,
|
| 32 |
+
const int* nnz_per_row);
|
| 33 |
+
|
| 34 |
+
// Functions that offset the appropriate type to the start of the data
|
| 35 |
+
// needed by the given thread id (|tid|).
|
| 36 |
+
template <typename WeightType>
|
| 37 |
+
const WeightType* OffsetWeights(const WeightType* weights, int tid) const {
|
| 38 |
+
return weights + weight_starts_[tid];
|
| 39 |
+
}
|
| 40 |
+
template <typename RhsIndType>
|
| 41 |
+
const RhsIndType* OffsetRhsIndices(const RhsIndType* rhs_indices,
|
| 42 |
+
int tid) const {
|
| 43 |
+
return rhs_indices + rhs_indices_starts_[tid];
|
| 44 |
+
}
|
| 45 |
+
template <typename BiasType>
|
| 46 |
+
const BiasType* OffsetBias(const BiasType* bias, int tid) const {
|
| 47 |
+
return bias + bias_starts_[tid];
|
| 48 |
+
}
|
| 49 |
+
template <typename OutType>
|
| 50 |
+
OutType* OffsetOutput(OutType* output, int tid) const {
|
| 51 |
+
return output + block_height_ * row_starts_[tid];
|
| 52 |
+
}
|
| 53 |
+
int StartRow(int tid) const { return row_starts_[tid]; }
|
| 54 |
+
const std::vector<int>& row_starts() const { return row_starts_; }
|
| 55 |
+
|
| 56 |
+
private:
|
| 57 |
+
// Computes the block row (reduced) index of the start of each thread.
|
| 58 |
+
void ComputeThreadSplitPoints(int num_threads, int reduced_rows_per_cache_row,
|
| 59 |
+
int reduced_rows, const int* nnz_per_row);
|
| 60 |
+
|
| 61 |
+
// Sizes of a sparse block.
|
| 62 |
+
int block_width_;
|
| 63 |
+
int block_height_;
|
| 64 |
+
// Start indices of each data type by thread-id with an extra value at the
|
| 65 |
+
// end.
|
| 66 |
+
std::vector<int> row_starts_;
|
| 67 |
+
std::vector<int> weight_starts_;
|
| 68 |
+
std::vector<int> rhs_indices_starts_;
|
| 69 |
+
std::vector<int> bias_starts_;
|
| 70 |
+
};
|
| 71 |
+
|
| 72 |
+
} // namespace csrblocksparse
|
| 73 |
+
|
| 74 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_THREAD_BOUNDS_H_
|
sparse_matmul/layers/BUILD
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Sparse/Masked Matrix and Layer.
|
| 2 |
+
|
| 3 |
+
# [internal] load android_library_selector
|
| 4 |
+
# [internal] load android_cc_test:def.bzl
|
| 5 |
+
|
| 6 |
+
licenses(["notice"])
|
| 7 |
+
|
| 8 |
+
cc_library(
|
| 9 |
+
name = "layer",
|
| 10 |
+
hdrs = [
|
| 11 |
+
"sparse_linear_layer.h",
|
| 12 |
+
],
|
| 13 |
+
visibility = [
|
| 14 |
+
"//sparse_matmul:__subpackages__",
|
| 15 |
+
],
|
| 16 |
+
deps = [
|
| 17 |
+
":matrix",
|
| 18 |
+
"//sparse_matmul/numerics:types",
|
| 19 |
+
"//sparse_matmul/os:coop_threads",
|
| 20 |
+
"//sparse_matmul/vector:cache_aligned_vector",
|
| 21 |
+
"@com_google_absl//absl/memory",
|
| 22 |
+
"@com_google_absl//absl/strings:str_format",
|
| 23 |
+
"@com_google_glog//:glog",
|
| 24 |
+
],
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
cc_library(
|
| 28 |
+
name = "matrix",
|
| 29 |
+
hdrs = [
|
| 30 |
+
"csr_blocksparse_matrix.h",
|
| 31 |
+
"masked_sparse_matrix.h",
|
| 32 |
+
],
|
| 33 |
+
visibility = [
|
| 34 |
+
"//sparse_matmul:__subpackages__",
|
| 35 |
+
],
|
| 36 |
+
deps = [
|
| 37 |
+
"//sparse_matmul/compute:kernels",
|
| 38 |
+
"//sparse_matmul/compute:matmul",
|
| 39 |
+
"//sparse_matmul/compute:thread_bounds",
|
| 40 |
+
"//sparse_matmul/numerics:types",
|
| 41 |
+
"//sparse_matmul/os:coop_threads",
|
| 42 |
+
"//sparse_matmul/vector:cache_aligned_vector",
|
| 43 |
+
"@com_google_absl//absl/memory",
|
| 44 |
+
"@com_google_absl//absl/strings:str_format",
|
| 45 |
+
"@com_google_glog//:glog",
|
| 46 |
+
],
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
cc_library(
|
| 50 |
+
name = "utils",
|
| 51 |
+
srcs = [
|
| 52 |
+
"utils.cc",
|
| 53 |
+
],
|
| 54 |
+
hdrs = [
|
| 55 |
+
"read_array_ifstream.h",
|
| 56 |
+
"utils.h",
|
| 57 |
+
],
|
| 58 |
+
visibility = [
|
| 59 |
+
"//sparse_matmul:__subpackages__",
|
| 60 |
+
],
|
| 61 |
+
deps = [
|
| 62 |
+
":layer",
|
| 63 |
+
":matrix",
|
| 64 |
+
":status",
|
| 65 |
+
"//sparse_matmul/numerics:types",
|
| 66 |
+
"//sparse_matmul/vector:cache_aligned_vector",
|
| 67 |
+
"//sparse_matmul/zlib_wrapper",
|
| 68 |
+
"@com_google_absl//absl/status",
|
| 69 |
+
"@com_google_absl//absl/strings",
|
| 70 |
+
"@com_google_absl//absl/strings:cord",
|
| 71 |
+
"@gulrak_filesystem//:filesystem",
|
| 72 |
+
],
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
cc_library(
|
| 76 |
+
name = "status",
|
| 77 |
+
srcs = [
|
| 78 |
+
"errno_mapping.cc",
|
| 79 |
+
],
|
| 80 |
+
hdrs = [
|
| 81 |
+
"errno_mapping.h",
|
| 82 |
+
"status_macros.h",
|
| 83 |
+
],
|
| 84 |
+
deps = [
|
| 85 |
+
"@com_google_absl//absl/status",
|
| 86 |
+
"@com_google_absl//absl/status:statusor",
|
| 87 |
+
"@com_google_absl//absl/strings",
|
| 88 |
+
"@com_google_absl//absl/strings:cord",
|
| 89 |
+
],
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
cc_test(
|
| 93 |
+
name = "csrblocksparse_test",
|
| 94 |
+
size = "small",
|
| 95 |
+
srcs = [
|
| 96 |
+
"csrblocksparse_test.cc",
|
| 97 |
+
],
|
| 98 |
+
data = glob(["testdata/*"]),
|
| 99 |
+
linkopts = select({
|
| 100 |
+
"@bazel_tools//platforms:android": ["-landroid"],
|
| 101 |
+
"//conditions:default": [],
|
| 102 |
+
}),
|
| 103 |
+
shard_count = 10,
|
| 104 |
+
deps = [
|
| 105 |
+
":status",
|
| 106 |
+
":utils",
|
| 107 |
+
"//sparse_matmul/compute:matmul",
|
| 108 |
+
"//sparse_matmul/numerics:test_utils",
|
| 109 |
+
"//sparse_matmul/os:coop_threads",
|
| 110 |
+
"@com_google_absl//absl/status",
|
| 111 |
+
"@com_google_absl//absl/strings",
|
| 112 |
+
"@com_google_absl//absl/types:span",
|
| 113 |
+
"@com_google_googletest//:gtest_main",
|
| 114 |
+
"@gulrak_filesystem//:filesystem",
|
| 115 |
+
],
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
cc_test(
|
| 119 |
+
name = "sparse_linear_layer_test",
|
| 120 |
+
srcs = [
|
| 121 |
+
"sparse_linear_layer_test.cc",
|
| 122 |
+
],
|
| 123 |
+
deps = [
|
| 124 |
+
":layer",
|
| 125 |
+
"//sparse_matmul/numerics:test_utils",
|
| 126 |
+
"@com_google_googletest//:gtest_main",
|
| 127 |
+
],
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
cc_test(
|
| 131 |
+
name = "utils_test",
|
| 132 |
+
srcs = ["utils_test.cc"],
|
| 133 |
+
deps = [
|
| 134 |
+
":layer",
|
| 135 |
+
":matrix",
|
| 136 |
+
":status",
|
| 137 |
+
":utils",
|
| 138 |
+
"//sparse_matmul/numerics:fast_transcendentals",
|
| 139 |
+
"//sparse_matmul/numerics:test_utils",
|
| 140 |
+
"//sparse_matmul/numerics:types",
|
| 141 |
+
"//sparse_matmul/vector:cache_aligned_vector",
|
| 142 |
+
"@com_google_absl//absl/flags:flag",
|
| 143 |
+
"@com_google_googletest//:gtest_main",
|
| 144 |
+
"@gulrak_filesystem//:filesystem",
|
| 145 |
+
],
|
| 146 |
+
)
|
sparse_matmul/layers/csr_blocksparse_matrix.h
ADDED
|
@@ -0,0 +1,835 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2021 Google LLC
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_
|
| 18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_
|
| 19 |
+
|
| 20 |
+
#include <algorithm>
|
| 21 |
+
#include <cstdint>
|
| 22 |
+
#include <iostream>
|
| 23 |
+
#include <memory>
|
| 24 |
+
#include <tuple>
|
| 25 |
+
#include <vector>
|
| 26 |
+
|
| 27 |
+
#include "glog/logging.h"
|
| 28 |
+
// IWYU pragma: begin_exports
|
| 29 |
+
#include "sparse_matmul/compute/kernels_generic.h"
|
| 30 |
+
#include "sparse_matmul/compute/matmul.h"
|
| 31 |
+
#include "sparse_matmul/compute/thread_bounds.h"
|
| 32 |
+
#include "sparse_matmul/layers/masked_sparse_matrix.h"
|
| 33 |
+
#include "sparse_matmul/numerics/fixed_types.h"
|
| 34 |
+
#include "sparse_matmul/numerics/float16_types.h"
|
| 35 |
+
#include "sparse_matmul/os/coop_threads.h"
|
| 36 |
+
#include "sparse_matmul/vector/cache_aligned_vector.h"
|
| 37 |
+
// IWYU pragma: end_exports
|
| 38 |
+
#include "absl/memory/memory.h"
|
| 39 |
+
|
| 40 |
+
namespace csrblocksparse {
|
| 41 |
+
// CsrBlockSparseMatrix stores a modified block compressed sparse row
|
| 42 |
+
// representation of a sparse matrix. The ordering of the weights is modified
|
| 43 |
+
// in the 16x1 and 1x1 cases so that a certain number (4 and 8 respectively)
|
| 44 |
+
// of columns of weights are stored contiguously before moving on to the next
|
| 45 |
+
// row. The 4x4 case stores each block contiguously.
|
| 46 |
+
//
|
| 47 |
+
// Currently it is constructed from a MaskedSparseMatrix which usees a dense
|
| 48 |
+
// binary mask representation. The construction generates the compressed
|
| 49 |
+
// representation. Further iterations will support a direct serialization
|
| 50 |
+
// of the compressed representation.
|
| 51 |
+
//
|
| 52 |
+
// MaskedSparseMatrix masked_matrix(rows, cols, existing_mask, existing_values)
|
| 53 |
+
// CsrBlockSparseMatrix matrix(masked_matrix)
|
| 54 |
+
//
|
| 55 |
+
// matrix.SpMV_bias(rhs, bias, &out);
|
| 56 |
+
//
|
| 57 |
+
// This class is thread compatible.
|
| 58 |
+
template <typename WeightType, typename RhsType, typename DeltaType = int16_t>
|
| 59 |
+
class CsrBlockSparseMatrix {
|
| 60 |
+
public:
|
| 61 |
+
CsrBlockSparseMatrix() {}
|
| 62 |
+
|
| 63 |
+
// Reference used to indicate that this is an input and not an output.
|
| 64 |
+
CsrBlockSparseMatrix(const uint8_t* const& buffer, const std::size_t& len) {
|
| 65 |
+
ReadFromFlatBuffer(buffer, len);
|
| 66 |
+
ComputeRHSIndices();
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
template <typename InputType>
|
| 70 |
+
CsrBlockSparseMatrix(const MaskedSparseMatrix<InputType>& masked_matrix) {
|
| 71 |
+
sparsity_ = masked_matrix.sparsity();
|
| 72 |
+
rows_ = masked_matrix.rows();
|
| 73 |
+
cols_ = masked_matrix.cols();
|
| 74 |
+
|
| 75 |
+
DetermineBlockSize(masked_matrix);
|
| 76 |
+
|
| 77 |
+
if (block_width_ == 1 && block_height_ == 1)
|
| 78 |
+
col_multiple_ = 8;
|
| 79 |
+
else
|
| 80 |
+
col_multiple_ = 1;
|
| 81 |
+
|
| 82 |
+
std::vector<InputType> weights(masked_matrix.values().begin(),
|
| 83 |
+
masked_matrix.values().end());
|
| 84 |
+
|
| 85 |
+
reduced_rows_ = (rows_ + block_height_ - 1) / block_height_;
|
| 86 |
+
rows_ = reduced_rows_ * block_height_;
|
| 87 |
+
reduced_cols_ = cols_ / block_width_;
|
| 88 |
+
|
| 89 |
+
// Calculate the reduced CSR representation of the matrix.
|
| 90 |
+
std::vector<int> reduced_mask(reduced_rows_ * reduced_cols_);
|
| 91 |
+
std::vector<int> row_offsets = {0};
|
| 92 |
+
int nnz = 0;
|
| 93 |
+
const auto& mask = masked_matrix.mask();
|
| 94 |
+
for (int r = 0; r < reduced_rows_; ++r) {
|
| 95 |
+
for (int c = 0; c < reduced_cols_; ++c) {
|
| 96 |
+
int mask_val = mask[r * block_height_ * cols_ + c * block_width_];
|
| 97 |
+
reduced_mask[r * reduced_cols_ + c] = mask_val;
|
| 98 |
+
nnz += mask_val;
|
| 99 |
+
}
|
| 100 |
+
row_offsets.push_back(nnz);
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
// Make sure the reduced representation has the correct number of columns.
|
| 104 |
+
MakeColumnsMultiple(row_offsets, &reduced_mask, &weights);
|
| 105 |
+
|
| 106 |
+
std::vector<int> col_indices;
|
| 107 |
+
std::vector<WeightType> weights_csr;
|
| 108 |
+
std::vector<int> nnz_per_row;
|
| 109 |
+
MaskAndWeightsToCsr(reduced_mask, weights, &nnz_per_row, &col_indices,
|
| 110 |
+
&weights_csr);
|
| 111 |
+
|
| 112 |
+
// Generate column deltas from |col_indices|.
|
| 113 |
+
std::vector<DeltaType> col_deltas;
|
| 114 |
+
for (int i = 0; i < col_indices.size(); ++i) {
|
| 115 |
+
// |col_indices| are used to index the RHS vector which is always float.
|
| 116 |
+
int64_t diff = sizeof(RhsType);
|
| 117 |
+
if (i == 0)
|
| 118 |
+
diff *= block_width_ * (col_indices[i]);
|
| 119 |
+
else
|
| 120 |
+
diff *= block_width_ * (col_indices[i] - col_indices[i - 1]);
|
| 121 |
+
|
| 122 |
+
CHECK(diff < std::numeric_limits<DeltaType>::max())
|
| 123 |
+
<< "delta between column indices in bytes " << diff
|
| 124 |
+
<< " exceeded the maximum size of the DeltaType "
|
| 125 |
+
<< std::numeric_limits<DeltaType>::max();
|
| 126 |
+
col_deltas.push_back(static_cast<DeltaType>(diff));
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
// Because of pre-fetching we need some extra values at the end.
|
| 130 |
+
col_deltas.insert(col_deltas.end(), std::max(2, col_multiple_ + 1), 0);
|
| 131 |
+
nnz_per_row.insert(nnz_per_row.end(), 2, nnz_per_row.back());
|
| 132 |
+
|
| 133 |
+
weights_ = CacheAlignedVector<WeightType>(weights_csr);
|
| 134 |
+
col_deltas_ = CacheAlignedVector<DeltaType>(col_deltas);
|
| 135 |
+
nnz_per_row_ = CacheAlignedVector<int>(nnz_per_row);
|
| 136 |
+
ComputeRHSIndices();
|
| 137 |
+
|
| 138 |
+
num_threads_ = 0;
|
| 139 |
+
PrepareForThreads(1);
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
// Constructor makes a matrix from the given weights, deltas and nnz, taking
|
| 143 |
+
// the other parameters from |src_matrix|. |cols| is the number of raw columns
|
| 144 |
+
// (NOT blocks) of the new matrix.
|
| 145 |
+
CsrBlockSparseMatrix(
|
| 146 |
+
const CsrBlockSparseMatrix<WeightType, RhsType, DeltaType>& src_matrix,
|
| 147 |
+
const std::vector<WeightType>& new_weights,
|
| 148 |
+
const std::vector<DeltaType>& new_deltas, const std::vector<int>& new_nnz,
|
| 149 |
+
int cols) {
|
| 150 |
+
num_threads_ = 0;
|
| 151 |
+
col_multiple_ = src_matrix.col_multiple_;
|
| 152 |
+
block_width_ = src_matrix.block_width_;
|
| 153 |
+
block_height_ = src_matrix.block_height_;
|
| 154 |
+
reduced_rows_ = new_nnz.size();
|
| 155 |
+
rows_ = reduced_rows_ * block_height_;
|
| 156 |
+
cols_ = cols;
|
| 157 |
+
reduced_cols_ = cols_ / block_width_;
|
| 158 |
+
weights_ = CacheAlignedVector<WeightType>(new_weights);
|
| 159 |
+
col_deltas_ = CacheAlignedVector<DeltaType>(new_deltas);
|
| 160 |
+
nnz_per_row_ = CacheAlignedVector<int>(new_nnz);
|
| 161 |
+
sparsity_ = 1.0f - static_cast<float>(new_weights.size()) / (rows_ * cols_);
|
| 162 |
+
ComputeRHSIndices();
|
| 163 |
+
name_ = src_matrix.name_;
|
| 164 |
+
PrepareForThreads(1);
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
// Factory method takes a column slice out of *this and returns a sparse
|
| 168 |
+
// matrix that takes as inputs [|start_col|, |end_col|) of *this, and
|
| 169 |
+
// returns the same number of outputs, but only a partial result.
|
| 170 |
+
// If |keep_rhs_size|, then the new matrix takes the same rhs as the current
|
| 171 |
+
// matrix, but uses a subset of it, instead of expecting just the reduced rhs.
|
| 172 |
+
// If |start_col| > |end_col|, then we slice out the complement of the defined
|
| 173 |
+
// interval, ie [0, |end_col|) + [|start_col|, current end).
|
| 174 |
+
// NOTE That |start_col| and |end_col| are in raw column coordinates, NOT
|
| 175 |
+
// block units.
|
| 176 |
+
CsrBlockSparseMatrix SplitByColumn(int start_col, int end_col,
|
| 177 |
+
bool keep_rhs_size = false) const {
|
| 178 |
+
int weight_index = 0;
|
| 179 |
+
int delta_index = 0;
|
| 180 |
+
std::vector<DeltaType> new_deltas;
|
| 181 |
+
std::vector<WeightType> new_weights;
|
| 182 |
+
std::vector<int> new_nnz(reduced_rows_);
|
| 183 |
+
int col = 0;
|
| 184 |
+
int prev_col = keep_rhs_size ? 0 : start_col;
|
| 185 |
+
for (int r = 0; r < reduced_rows_; ++r) {
|
| 186 |
+
int reduced_col_count = nnz_per_row_[r];
|
| 187 |
+
for (int c = 0; c < reduced_col_count; ++c, ++delta_index) {
|
| 188 |
+
col += col_deltas_[delta_index] / sizeof(RhsType);
|
| 189 |
+
if ((start_col < end_col && start_col <= col && col < end_col) ||
|
| 190 |
+
(start_col > end_col && (col < end_col || col >= start_col))) {
|
| 191 |
+
++new_nnz[r];
|
| 192 |
+
new_deltas.push_back((col - prev_col) * sizeof(RhsType));
|
| 193 |
+
prev_col = col;
|
| 194 |
+
for (int i = 0; i < block_width_ * block_height_;
|
| 195 |
+
++i, ++weight_index) {
|
| 196 |
+
new_weights.push_back(weights_[weight_index]);
|
| 197 |
+
}
|
| 198 |
+
} else {
|
| 199 |
+
weight_index += block_width_ * block_height_;
|
| 200 |
+
}
|
| 201 |
+
}
|
| 202 |
+
}
|
| 203 |
+
int new_cols = keep_rhs_size ? cols_ : end_col - start_col;
|
| 204 |
+
return CsrBlockSparseMatrix(*this, new_weights, new_deltas, new_nnz,
|
| 205 |
+
new_cols);
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
// Factory method takes a row slice out of *this and returns a sparse
|
| 209 |
+
// matrix that takes the sampe inputs as *this, and returns the outputs for
|
| 210 |
+
// the range [|start_row|, |end_row|).
|
| 211 |
+
// NOTE That |start_row| and |end_row| are in raw column coordinates, NOT
|
| 212 |
+
// block units.
|
| 213 |
+
CsrBlockSparseMatrix SplitByRow(int start_row, int end_row) const {
|
| 214 |
+
int start_reduced = start_row / block_height_;
|
| 215 |
+
int end_reduced = end_row / block_height_;
|
| 216 |
+
std::vector<int> new_nnz(nnz_per_row_.data() + start_reduced,
|
| 217 |
+
nnz_per_row_.data() + end_reduced);
|
| 218 |
+
int weight_start = 0;
|
| 219 |
+
for (int r = 0; r < start_reduced; ++r) {
|
| 220 |
+
weight_start += nnz_per_row_[r];
|
| 221 |
+
}
|
| 222 |
+
int weight_end = weight_start;
|
| 223 |
+
for (int r = start_reduced; r < end_reduced; ++r) {
|
| 224 |
+
weight_end += nnz_per_row_[r];
|
| 225 |
+
}
|
| 226 |
+
int delta_start = 0;
|
| 227 |
+
for (int i = 0; i < weight_start; ++i) {
|
| 228 |
+
delta_start += col_deltas_[i];
|
| 229 |
+
}
|
| 230 |
+
std::vector<DeltaType> new_deltas(col_deltas_.data() + weight_start,
|
| 231 |
+
col_deltas_.data() + weight_end);
|
| 232 |
+
new_deltas[0] += delta_start;
|
| 233 |
+
int block_size = block_height_ * block_width_;
|
| 234 |
+
std::vector<WeightType> new_weights(
|
| 235 |
+
weights_.data() + weight_start * block_size,
|
| 236 |
+
weights_.data() + weight_end * block_size);
|
| 237 |
+
return CsrBlockSparseMatrix(*this, new_weights, new_deltas, new_nnz, cols_);
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
// Combines adjacent row blocks, doubling the block height.
|
| 241 |
+
// This necessarily involves adding zero weights where the blocks don't align
|
| 242 |
+
// across adjacent pairs of rows, so use with caution, as the resulting matrix
|
| 243 |
+
// is most likely to run slower if very sparse to begin with.
|
| 244 |
+
// In the few cases where the blocks do mostly align, the resulting matmul
|
| 245 |
+
// could be much faster, as the number of reads of the rhs will be halved.
|
| 246 |
+
void DoubleBlockHeight() {
|
| 247 |
+
int new_rows = reduced_rows_ / 2;
|
| 248 |
+
std::vector<int> new_nnz(new_rows);
|
| 249 |
+
std::vector<DeltaType> new_rhs_indices;
|
| 250 |
+
std::vector<WeightType> new_weights;
|
| 251 |
+
int rhs_index1 = 0;
|
| 252 |
+
int rhs_index2 = 0;
|
| 253 |
+
int block_size = block_height_ * block_width_;
|
| 254 |
+
for (int r = 0; r < new_rows; ++r) {
|
| 255 |
+
int start_nnz = new_rhs_indices.size();
|
| 256 |
+
rhs_index2 += nnz_per_row_[r * 2];
|
| 257 |
+
int end1 = rhs_index1 + nnz_per_row_[r * 2];
|
| 258 |
+
int end2 = rhs_index2 + nnz_per_row_[r * 2 + 1];
|
| 259 |
+
// Run over a pair of rows with 2 iterators, combining blocks as we go, or
|
| 260 |
+
// padding with zeros where the block positions don't match.
|
| 261 |
+
while (rhs_index1 < end1 || rhs_index2 < end2) {
|
| 262 |
+
int col1 = rhs_index1 < end1 ? rhs_indices_[rhs_index1] : reduced_cols_;
|
| 263 |
+
int col2 = rhs_index2 < end2 ? rhs_indices_[rhs_index2] : reduced_cols_;
|
| 264 |
+
if (col1 < col2) {
|
| 265 |
+
// Need zero weights for row2 to pad out weights block.
|
| 266 |
+
new_rhs_indices.push_back(col1);
|
| 267 |
+
new_weights.insert(new_weights.end(),
|
| 268 |
+
weights_.data() + rhs_index1 * block_size,
|
| 269 |
+
weights_.data() + (rhs_index1 + 1) * block_size);
|
| 270 |
+
new_weights.insert(new_weights.end(), block_size,
|
| 271 |
+
static_cast<WeightType>(0.0f));
|
| 272 |
+
++rhs_index1;
|
| 273 |
+
} else if (col1 > col2) {
|
| 274 |
+
// Need zero weights for row1 to pad out weights block.
|
| 275 |
+
new_rhs_indices.push_back(col2);
|
| 276 |
+
new_weights.insert(new_weights.end(), block_size,
|
| 277 |
+
static_cast<WeightType>(0.0f));
|
| 278 |
+
new_weights.insert(new_weights.end(),
|
| 279 |
+
weights_.data() + rhs_index2 * block_size,
|
| 280 |
+
weights_.data() + (rhs_index2 + 1) * block_size);
|
| 281 |
+
++rhs_index2;
|
| 282 |
+
} else {
|
| 283 |
+
// Combine weights for both row1 and row2.
|
| 284 |
+
new_rhs_indices.push_back(col1);
|
| 285 |
+
new_weights.insert(new_weights.end(),
|
| 286 |
+
weights_.data() + rhs_index1 * block_size,
|
| 287 |
+
weights_.data() + (rhs_index1 + 1) * block_size);
|
| 288 |
+
new_weights.insert(new_weights.end(),
|
| 289 |
+
weights_.data() + rhs_index2 * block_size,
|
| 290 |
+
weights_.data() + (rhs_index2 + 1) * block_size);
|
| 291 |
+
++rhs_index1;
|
| 292 |
+
++rhs_index2;
|
| 293 |
+
}
|
| 294 |
+
}
|
| 295 |
+
rhs_index1 = rhs_index2;
|
| 296 |
+
new_nnz[r] = new_rhs_indices.size() - start_nnz;
|
| 297 |
+
}
|
| 298 |
+
block_height_ *= 2;
|
| 299 |
+
reduced_rows_ /= 2;
|
| 300 |
+
weights_ = CacheAlignedVector<WeightType>(new_weights);
|
| 301 |
+
rhs_indices_ = CacheAlignedVector<DeltaType>(new_rhs_indices);
|
| 302 |
+
nnz_per_row_ = CacheAlignedVector<int>(new_nnz);
|
| 303 |
+
sparsity_ = 1.0f - static_cast<float>(new_weights.size()) / (rows_ * cols_);
|
| 304 |
+
ComputeColDeltas();
|
| 305 |
+
if (num_threads_ > 0) {
|
| 306 |
+
int num_threads = num_threads_;
|
| 307 |
+
num_threads_ = 0;
|
| 308 |
+
PrepareForThreads(num_threads);
|
| 309 |
+
}
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
// Allocates memory and fills buffer.
|
| 313 |
+
// Caller is responsible for the memory de-allocation.
|
| 314 |
+
// TODO(b/189958858): Both Read and Write need to eventually handle the
|
| 315 |
+
// different possible HalfType and DeltaType values, but punting for now as
|
| 316 |
+
// there is only one supported combination.
|
| 317 |
+
std::size_t WriteToFlatBuffer(std::string* csr_flatbuffer) {
|
| 318 |
+
std::size_t bytes = 0;
|
| 319 |
+
bytes += FixedParameterSize();
|
| 320 |
+
bytes += weights_.size() * sizeof(WeightType);
|
| 321 |
+
bytes += col_deltas_.size() * sizeof(DeltaType);
|
| 322 |
+
bytes += nnz_per_row_.size() * sizeof(int);
|
| 323 |
+
|
| 324 |
+
uint8_t* bytes_ptr_ptr =
|
| 325 |
+
reinterpret_cast<uint8_t*>(CHECK_NOTNULL(malloc(bytes)));
|
| 326 |
+
|
| 327 |
+
int* int_bytes_ptr = reinterpret_cast<int*>(bytes_ptr_ptr);
|
| 328 |
+
|
| 329 |
+
*int_bytes_ptr++ = rows_;
|
| 330 |
+
*int_bytes_ptr++ = cols_;
|
| 331 |
+
*int_bytes_ptr++ = reduced_rows_;
|
| 332 |
+
*int_bytes_ptr++ = reduced_cols_;
|
| 333 |
+
*int_bytes_ptr++ = block_width_;
|
| 334 |
+
*int_bytes_ptr++ = block_height_;
|
| 335 |
+
*int_bytes_ptr++ = col_multiple_;
|
| 336 |
+
*int_bytes_ptr++ = num_threads_;
|
| 337 |
+
*int_bytes_ptr++ = weights_.size();
|
| 338 |
+
*int_bytes_ptr++ = col_deltas_.size();
|
| 339 |
+
*int_bytes_ptr++ = nnz_per_row_.size();
|
| 340 |
+
|
| 341 |
+
float* float_bytes_ptr = reinterpret_cast<float*>(int_bytes_ptr);
|
| 342 |
+
*float_bytes_ptr++ = sparsity_;
|
| 343 |
+
|
| 344 |
+
uint8_t* bytes_ptr = reinterpret_cast<uint8_t*>(float_bytes_ptr);
|
| 345 |
+
|
| 346 |
+
memcpy(bytes_ptr, weights_.data(), weights_.size() * sizeof(WeightType));
|
| 347 |
+
bytes_ptr += weights_.size() * sizeof(WeightType);
|
| 348 |
+
|
| 349 |
+
memcpy(bytes_ptr, col_deltas_.data(),
|
| 350 |
+
col_deltas_.size() * sizeof(DeltaType));
|
| 351 |
+
bytes_ptr += col_deltas_.size() * sizeof(DeltaType);
|
| 352 |
+
|
| 353 |
+
memcpy(bytes_ptr, nnz_per_row_.data(), nnz_per_row_.size() * sizeof(int));
|
| 354 |
+
bytes_ptr += nnz_per_row_.size() * sizeof(int);
|
| 355 |
+
|
| 356 |
+
csr_flatbuffer->resize(bytes);
|
| 357 |
+
csr_flatbuffer->assign(reinterpret_cast<char*>(bytes_ptr_ptr), bytes);
|
| 358 |
+
free(bytes_ptr_ptr);
|
| 359 |
+
|
| 360 |
+
return bytes;
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
void ReadFromFlatBuffer(const uint8_t* const& bytes, const std::size_t& len) {
|
| 364 |
+
CHECK_GE(len, FixedParameterSize());
|
| 365 |
+
|
| 366 |
+
const int* int_bytes_ptr = reinterpret_cast<const int*>(bytes);
|
| 367 |
+
rows_ = *int_bytes_ptr++;
|
| 368 |
+
cols_ = *int_bytes_ptr++;
|
| 369 |
+
reduced_rows_ = *int_bytes_ptr++;
|
| 370 |
+
reduced_cols_ = *int_bytes_ptr++;
|
| 371 |
+
block_width_ = *int_bytes_ptr++;
|
| 372 |
+
block_height_ = *int_bytes_ptr++;
|
| 373 |
+
col_multiple_ = *int_bytes_ptr++;
|
| 374 |
+
int num_threads = *int_bytes_ptr++;
|
| 375 |
+
int32_t weights_size = *int_bytes_ptr++;
|
| 376 |
+
int32_t col_deltas_size = *int_bytes_ptr++;
|
| 377 |
+
int32_t nnz_per_row_size = *int_bytes_ptr++;
|
| 378 |
+
|
| 379 |
+
// Make sure negative sizes don't mess things up.
|
| 380 |
+
weights_size = std::max(0, weights_size);
|
| 381 |
+
col_deltas_size = std::max(0, col_deltas_size);
|
| 382 |
+
nnz_per_row_size = std::max(0, nnz_per_row_size);
|
| 383 |
+
|
| 384 |
+
const float* float_bytes_ptr =
|
| 385 |
+
reinterpret_cast<const float*>(int_bytes_ptr);
|
| 386 |
+
sparsity_ = *float_bytes_ptr++;
|
| 387 |
+
|
| 388 |
+
std::size_t total_bytes =
|
| 389 |
+
FixedParameterSize() + weights_size * sizeof(WeightType) +
|
| 390 |
+
col_deltas_size * sizeof(DeltaType) + nnz_per_row_size * sizeof(int);
|
| 391 |
+
|
| 392 |
+
CHECK_EQ(total_bytes, len)
|
| 393 |
+
<< "total bytes: " << total_bytes << ", actual len given: " << len;
|
| 394 |
+
|
| 395 |
+
const uint8_t* bytes_ptr =
|
| 396 |
+
reinterpret_cast<const uint8_t*>(float_bytes_ptr);
|
| 397 |
+
std::vector<WeightType> weights_raw(weights_size);
|
| 398 |
+
memcpy(weights_raw.data(), bytes_ptr, weights_size * sizeof(WeightType));
|
| 399 |
+
weights_ = CacheAlignedVector<WeightType>(weights_raw);
|
| 400 |
+
bytes_ptr += weights_size * sizeof(WeightType);
|
| 401 |
+
|
| 402 |
+
std::vector<DeltaType> deltas_raw(col_deltas_size);
|
| 403 |
+
memcpy(deltas_raw.data(), bytes_ptr, col_deltas_size * sizeof(DeltaType));
|
| 404 |
+
col_deltas_ = CacheAlignedVector<DeltaType>(deltas_raw);
|
| 405 |
+
bytes_ptr += col_deltas_size * sizeof(DeltaType);
|
| 406 |
+
|
| 407 |
+
std::vector<int> nnz_raw(nnz_per_row_size);
|
| 408 |
+
memcpy(nnz_raw.data(), bytes_ptr, nnz_per_row_size * sizeof(int));
|
| 409 |
+
nnz_per_row_ = CacheAlignedVector<int>(nnz_raw);
|
| 410 |
+
num_threads_ = 0;
|
| 411 |
+
PrepareForThreads(num_threads);
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
// Multiply a Sparse matrix by a possibly dense matrix. Often the matrix is
|
| 415 |
+
// a vector with a small number of columns, hence the term "fat vector".
|
| 416 |
+
// 1x1 and 4x4 have specializations for output columns (ie fatness) > 5,
|
| 417 |
+
// and often achieve twice as many GFlops when multiplying a right hand side
|
| 418 |
+
// that has 5 or more columns. (Best is a multiple of 5).
|
| 419 |
+
// 16x1 doesn't have enough registers and just loops over the width 1 kernel.
|
| 420 |
+
//
|
| 421 |
+
// |rhs| and |out| are COLUMN MAJOR.
|
| 422 |
+
|
| 423 |
+
// Fast Tuples WeightType, BiasType, RhsType, OutType are:
|
| 424 |
+
// (float, float, float, float)
|
| 425 |
+
// (bfloat16, float, float, float)
|
| 426 |
+
// and only on ARM64. All other cases use a slow generic implementation.
|
| 427 |
+
template <typename RhsClass, typename BiasClass, typename OutClass,
|
| 428 |
+
typename BiasType = typename BiasClass::value_type,
|
| 429 |
+
typename OutType = typename OutClass::value_type>
|
| 430 |
+
void SpMM_bias(const RhsClass& rhs, const BiasClass& bias, OutClass* out,
|
| 431 |
+
bool relu = false, int tid = 0,
|
| 432 |
+
SpinBarrier* barrier = nullptr) const {
|
| 433 |
+
static_assert(std::is_same<typename RhsClass::value_type, RhsType>::value,
|
| 434 |
+
"Rhs types must match");
|
| 435 |
+
CHECK_LT(tid, num_threads_);
|
| 436 |
+
CHECK_EQ(rhs.cols(), out->cols());
|
| 437 |
+
CHECK_EQ(rhs.rows(), cols_);
|
| 438 |
+
CHECK_GE(out->rows(), rows_);
|
| 439 |
+
int cols_to_go = out->cols();
|
| 440 |
+
int rhs_index = *thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid);
|
| 441 |
+
const RhsType* rhs_ptr = rhs.data() + rhs_index * block_height_;
|
| 442 |
+
OutType* out_ptr = thread_bounds_.OffsetOutput(out->data(), tid);
|
| 443 |
+
const WeightType* weights_ptr =
|
| 444 |
+
thread_bounds_.OffsetWeights(weights_.data(), tid);
|
| 445 |
+
const DeltaType* delta_ptr =
|
| 446 |
+
thread_bounds_.OffsetRhsIndices(col_deltas_.data(), tid);
|
| 447 |
+
int offset = *delta_ptr / sizeof(RhsType);
|
| 448 |
+
rhs_ptr -= offset;
|
| 449 |
+
const int* nnz_ptr = nnz_per_row_.data() + thread_bounds_.StartRow(tid);
|
| 450 |
+
int assigned_rows =
|
| 451 |
+
thread_bounds_.StartRow(tid + 1) - thread_bounds_.StartRow(tid);
|
| 452 |
+
const BiasType* bias_ptr = thread_bounds_.OffsetBias(bias.data(), tid);
|
| 453 |
+
|
| 454 |
+
while (cols_to_go > 0) {
|
| 455 |
+
if (block_width_ == 4 && block_height_ == 4) {
|
| 456 |
+
if (cols_to_go >= 5) {
|
| 457 |
+
detail::SpMM5_4x4<WeightType, RhsType, OutType>(
|
| 458 |
+
weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
|
| 459 |
+
assigned_rows, out->col_stride(), rhs.col_stride(), relu);
|
| 460 |
+
} else {
|
| 461 |
+
detail::SpMV_4x4<WeightType, RhsType, OutType>(
|
| 462 |
+
weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
|
| 463 |
+
assigned_rows, out->col_stride(), rhs.col_stride(), relu);
|
| 464 |
+
}
|
| 465 |
+
} else {
|
| 466 |
+
if (cols_to_go >= 5) {
|
| 467 |
+
detail::SpMM5_1x1<WeightType, RhsType, OutType>(
|
| 468 |
+
weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
|
| 469 |
+
assigned_rows, out->col_stride(), rhs.col_stride(), relu);
|
| 470 |
+
} else {
|
| 471 |
+
detail::SpMV_1x1<WeightType, RhsType, OutType>(
|
| 472 |
+
weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
|
| 473 |
+
assigned_rows, out->col_stride(), rhs.col_stride(), relu);
|
| 474 |
+
}
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
if (cols_to_go >= 5) {
|
| 478 |
+
cols_to_go -= 5;
|
| 479 |
+
rhs_ptr += rhs.col_stride() * 5;
|
| 480 |
+
out_ptr += out->col_stride() * 5;
|
| 481 |
+
} else {
|
| 482 |
+
cols_to_go--;
|
| 483 |
+
rhs_ptr += rhs.col_stride();
|
| 484 |
+
out_ptr += out->col_stride();
|
| 485 |
+
}
|
| 486 |
+
if (barrier) barrier->barrier();
|
| 487 |
+
}
|
| 488 |
+
}
|
| 489 |
+
template <typename MVRhsType, typename MVBiasType, typename OutType>
|
| 490 |
+
void MatVec(const MVRhsType* rhs, const MVBiasType* bias, bool relu, int tid,
|
| 491 |
+
int replicas, int output_stride, OutType* output) {
|
| 492 |
+
CHECK_LT(tid, num_threads_);
|
| 493 |
+
CHECK_EQ(block_width_, 4) << "Block width must be 4!";
|
| 494 |
+
if (block_height_ == 8) {
|
| 495 |
+
matmul_.MatVec8x4(
|
| 496 |
+
thread_bounds_.OffsetWeights(weights_.cast_data(), tid), rhs,
|
| 497 |
+
thread_bounds_.OffsetBias(bias, tid), nnz_per_row_.data(),
|
| 498 |
+
thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid),
|
| 499 |
+
thread_bounds_.StartRow(tid), thread_bounds_.StartRow(tid + 1), relu,
|
| 500 |
+
replicas, output_stride, thread_bounds_.OffsetOutput(output, tid));
|
| 501 |
+
} else {
|
| 502 |
+
CHECK_EQ(block_height_, 4) << "Block height must be 4 or 8!";
|
| 503 |
+
matmul_.MatVec4x4(
|
| 504 |
+
thread_bounds_.OffsetWeights(weights_.cast_data(), tid), rhs,
|
| 505 |
+
thread_bounds_.OffsetBias(bias, tid), nnz_per_row_.data(),
|
| 506 |
+
thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid),
|
| 507 |
+
thread_bounds_.StartRow(tid), thread_bounds_.StartRow(tid + 1), relu,
|
| 508 |
+
replicas, output_stride, thread_bounds_.OffsetOutput(output, tid));
|
| 509 |
+
}
|
| 510 |
+
}
|
| 511 |
+
|
| 512 |
+
int rows() const { return rows_; }
|
| 513 |
+
int cols() const { return cols_; }
|
| 514 |
+
int block_height() const { return block_height_; }
|
| 515 |
+
int block_width() const { return block_width_; }
|
| 516 |
+
float sparsity() const { return sparsity_; }
|
| 517 |
+
int num_threads() const { return num_threads_; }
|
| 518 |
+
const ThreadBounds& thread_bounds() const { return thread_bounds_; }
|
| 519 |
+
const CacheAlignedVector<DeltaType>& rhs_indices() const {
|
| 520 |
+
return rhs_indices_;
|
| 521 |
+
}
|
| 522 |
+
const std::string& name() const { return name_; }
|
| 523 |
+
void set_name(const std::string& name) { name_ = name; }
|
| 524 |
+
const std::vector<int>& split_points() const {
|
| 525 |
+
return thread_bounds_.row_starts();
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
+
std::size_t bytes() const {
|
| 529 |
+
return weights_.size() * sizeof(WeightType) +
|
| 530 |
+
col_deltas_.size() * sizeof(DeltaType) +
|
| 531 |
+
nnz_per_row_.size() * sizeof(int);
|
| 532 |
+
}
|
| 533 |
+
|
| 534 |
+
// Multiplies a sparse matrix by a possibly dense matrix, as SpMM_bias above,
|
| 535 |
+
// and then samples from the output (softmax distribution) layer.
|
| 536 |
+
template <typename RhsClass, typename BiasClass, typename OutClass,
|
| 537 |
+
typename BiasType = typename BiasClass::value_type,
|
| 538 |
+
typename OutType = typename OutClass::value_type>
|
| 539 |
+
typename std::enable_if<!IsFixed32Type<OutType>::value, int>::type
|
| 540 |
+
SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out,
|
| 541 |
+
float temperature, int tid, SpinBarrier* barrier,
|
| 542 |
+
std::minstd_rand* gen,
|
| 543 |
+
CacheAlignedVector<float>* scratch) const {
|
| 544 |
+
SpMM_bias(rhs, bias, out, /*relu=*/false, tid, barrier);
|
| 545 |
+
return out->Sample(temperature, gen, scratch);
|
| 546 |
+
}
|
| 547 |
+
// Fixed32 version.
|
| 548 |
+
template <typename RhsClass, typename BiasClass, typename OutClass,
|
| 549 |
+
typename BiasType = typename BiasClass::value_type,
|
| 550 |
+
typename OutType = typename OutClass::value_type>
|
| 551 |
+
typename std::enable_if<IsFixed32Type<OutType>::value, int>::type
|
| 552 |
+
SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out,
|
| 553 |
+
float temperature, int tid, SpinBarrier* barrier,
|
| 554 |
+
std::minstd_rand* gen,
|
| 555 |
+
CacheAlignedVector<float>* scratch) const {
|
| 556 |
+
// We don't pass the barrier on, as we have more work to do.
|
| 557 |
+
SpMM_bias(rhs, bias, out, /*relu=*/false, tid);
|
| 558 |
+
return out->ReducingSample(gen, scratch, tid, temperature, barrier);
|
| 559 |
+
}
|
| 560 |
+
|
| 561 |
+
void Print() const {
|
| 562 |
+
std::cout << "Weights\n";
|
| 563 |
+
weights_.Print();
|
| 564 |
+
std::cout << std::endl;
|
| 565 |
+
std::cout << "Deltas\n";
|
| 566 |
+
col_deltas_.Print();
|
| 567 |
+
std::cout << std::endl;
|
| 568 |
+
std::cout << "nnz\n";
|
| 569 |
+
nnz_per_row_.Print();
|
| 570 |
+
std::cout << std::endl;
|
| 571 |
+
}
|
| 572 |
+
|
| 573 |
+
// Split the computation amongst threads by rows based on the number of
|
| 574 |
+
// non zeros, with the addition of a constant to account for the work of the
|
| 575 |
+
// bias and the horizontal add at the end, and also guarantees that each
|
| 576 |
+
// thread writes only whole cache lines, based on the size of OutType.
|
| 577 |
+
// The |cache_line_size| arg is used only for testing. Normally it is provided
|
| 578 |
+
// through the architecture #defines.
|
| 579 |
+
// Each thread gets a contiguous row range (|split_points|).
|
| 580 |
+
// Thread t does rows [ split_points[t], split_points[t + 1] )
|
| 581 |
+
// Each thread also needs to know how many non zeros were before it to skip
|
| 582 |
+
// (|nnz_to_skip|). And finally it also needs to know what the offset into
|
| 583 |
+
// the rhs vector would have been at the split point (|rhs_to_skip|).
|
| 584 |
+
//
|
| 585 |
+
// Some tricky corner cases where the number of non-zeros doesn't split
|
| 586 |
+
// nicely amongst the number of requested threads are not handled and default
|
| 587 |
+
// to one thread; these cases are only going to happen in tests and not in
|
| 588 |
+
// the matrices that correspond in real models.
|
| 589 |
+
//
|
| 590 |
+
// Returns the maximum number of threads that can be used; <= |num_threads|.
|
| 591 |
+
template <typename OutType = int32_t>
|
| 592 |
+
int PrepareForThreads(int num_threads, int cache_line_size = -1) {
|
| 593 |
+
CHECK_GT(num_threads, 0);
|
| 594 |
+
// we've already prepared for this number of threads, nothing to do
|
| 595 |
+
if (num_threads == num_threads_) return num_threads_;
|
| 596 |
+
|
| 597 |
+
num_threads_ = num_threads;
|
| 598 |
+
thread_bounds_.PrepareForThreads(
|
| 599 |
+
block_width_, block_height_, num_threads_,
|
| 600 |
+
ReducedRowsPerCacheLine<OutType>(cache_line_size), reduced_rows_,
|
| 601 |
+
nnz_per_row_.data());
|
| 602 |
+
return num_threads_;
|
| 603 |
+
}
|
| 604 |
+
|
| 605 |
+
// Computes and stores the |rhs_indices_| from the |col_deltas_|.
|
| 606 |
+
void ComputeRHSIndices() {
|
| 607 |
+
std::vector<int> cumulative_deltas = CumulativeColDeltas();
|
| 608 |
+
std::vector<DeltaType> rhs_indices(cumulative_deltas.size() +
|
| 609 |
+
reduced_rows_);
|
| 610 |
+
int total_indices = 0;
|
| 611 |
+
int delta_index = 0;
|
| 612 |
+
for (int r = 0; r < reduced_rows_; ++r) {
|
| 613 |
+
for (int n = 0; n < nnz_per_row_[r]; ++n, ++delta_index) {
|
| 614 |
+
rhs_indices[total_indices++] =
|
| 615 |
+
cumulative_deltas[delta_index] / block_width_;
|
| 616 |
+
}
|
| 617 |
+
}
|
| 618 |
+
rhs_indices_ = CacheAlignedVector<DeltaType>(rhs_indices);
|
| 619 |
+
}
|
| 620 |
+
|
| 621 |
+
// Computes and stores the |col_deltas_| from the |rhs_indices_|.
|
| 622 |
+
void ComputeColDeltas() {
|
| 623 |
+
std::vector<int> col_deltas(rhs_indices_.size());
|
| 624 |
+
int prev_index = 0;
|
| 625 |
+
for (int i = 0; i < rhs_indices_.size(); ++i) {
|
| 626 |
+
int offset = rhs_indices_[i] - prev_index;
|
| 627 |
+
prev_index = rhs_indices_[i];
|
| 628 |
+
col_deltas[i] = offset * block_width_ * sizeof(RhsType);
|
| 629 |
+
}
|
| 630 |
+
col_deltas_ = CacheAlignedVector<DeltaType>(col_deltas);
|
| 631 |
+
}
|
| 632 |
+
|
| 633 |
+
// Computes and returns the inclusive prefix sum of the deltas, ie absolute
|
| 634 |
+
// positions.
|
| 635 |
+
std::vector<int> CumulativeColDeltas() const {
|
| 636 |
+
std::vector<int> cum_col_deltas(col_deltas_.size());
|
| 637 |
+
for (int i = 0; i < col_deltas_.size(); ++i) {
|
| 638 |
+
cum_col_deltas[i] = col_deltas_[i] / sizeof(RhsType);
|
| 639 |
+
if (i > 0) cum_col_deltas[i] += cum_col_deltas[i - 1];
|
| 640 |
+
}
|
| 641 |
+
return cum_col_deltas;
|
| 642 |
+
}
|
| 643 |
+
|
| 644 |
+
private:
|
| 645 |
+
constexpr std::size_t FixedParameterSize() const {
|
| 646 |
+
return sizeof(int) // rows
|
| 647 |
+
+ sizeof(int) // cols
|
| 648 |
+
+ sizeof(int) // reduced_rows
|
| 649 |
+
+ sizeof(int) // reduced_cols
|
| 650 |
+
+ sizeof(int) // block_width
|
| 651 |
+
+ sizeof(int) // block_height
|
| 652 |
+
+ sizeof(float) // sparsity
|
| 653 |
+
+ sizeof(int) // col_multiple
|
| 654 |
+
+ sizeof(int) // num_threads_
|
| 655 |
+
+ sizeof(int) // weights_.size()
|
| 656 |
+
+ sizeof(int) // col_deltas_.size()
|
| 657 |
+
+ sizeof(int); // nnz_per_row_.size()
|
| 658 |
+
}
|
| 659 |
+
// Possible block sizes are only those that are supported by the computation
|
| 660 |
+
// default is 1x1, other options are 4x4 and 16x1.
|
| 661 |
+
template <typename InputType>
|
| 662 |
+
void DetermineBlockSize(const MaskedSparseMatrix<InputType>& masked_matrix) {
|
| 663 |
+
const std::vector<std::pair<int, int>> kPreferredOrder = {{4, 4}};
|
| 664 |
+
int rows = masked_matrix.rows();
|
| 665 |
+
int cols = masked_matrix.cols();
|
| 666 |
+
|
| 667 |
+
for (const auto& block_size : kPreferredOrder) {
|
| 668 |
+
int block_height, block_width;
|
| 669 |
+
std::tie(block_height, block_width) = block_size;
|
| 670 |
+
if (cols % block_width != 0) continue;
|
| 671 |
+
|
| 672 |
+
int reduced_rows = (rows + block_height - 1) / block_height;
|
| 673 |
+
int reduced_cols = cols / block_width;
|
| 674 |
+
|
| 675 |
+
// For each possible block, confirm that it is either all 0s or all 1s.
|
| 676 |
+
bool all_same = true;
|
| 677 |
+
const auto& mask = masked_matrix.mask();
|
| 678 |
+
for (int r = 0; r < reduced_rows; ++r) {
|
| 679 |
+
for (int c = 0; c < reduced_cols; ++c) {
|
| 680 |
+
int val = mask[r * block_height * cols + c * block_width];
|
| 681 |
+
for (int i = 0; i < block_height; ++i) {
|
| 682 |
+
for (int j = 0; j < block_width; ++j) {
|
| 683 |
+
int index = (r * block_height + i) * cols + c * block_width + j;
|
| 684 |
+
if (index < masked_matrix.mask().size()) {
|
| 685 |
+
all_same &= (masked_matrix.mask()[index] == val);
|
| 686 |
+
}
|
| 687 |
+
}
|
| 688 |
+
}
|
| 689 |
+
}
|
| 690 |
+
}
|
| 691 |
+
|
| 692 |
+
// If this block configuration is possible, accept it.
|
| 693 |
+
if (all_same) {
|
| 694 |
+
block_height_ = block_height;
|
| 695 |
+
block_width_ = block_width;
|
| 696 |
+
return;
|
| 697 |
+
}
|
| 698 |
+
}
|
| 699 |
+
|
| 700 |
+
// No large blocks were found, default to 1x1.
|
| 701 |
+
block_height_ = 1;
|
| 702 |
+
block_width_ = 1;
|
| 703 |
+
}
|
| 704 |
+
|
| 705 |
+
// CSR descriptors are for the reduced matrix, weights is the full matrix.
|
| 706 |
+
template <typename InputType>
|
| 707 |
+
void MakeColumnsMultiple(const std::vector<int>& row_offsets,
|
| 708 |
+
std::vector<int>* reduced_mask,
|
| 709 |
+
std::vector<InputType>* weights) {
|
| 710 |
+
if (col_multiple_ > 0) {
|
| 711 |
+
// Make sure each row has a number of columns that is a multiple of
|
| 712 |
+
// |col_multiple|.
|
| 713 |
+
for (int r = 1; r < row_offsets.size(); ++r) {
|
| 714 |
+
int num_row = row_offsets[r] - row_offsets[r - 1];
|
| 715 |
+
int num_needed = col_multiple_ - num_row % col_multiple_;
|
| 716 |
+
if (num_needed < col_multiple_) {
|
| 717 |
+
// Find gaps in the columns where we can insert a column of 0 weights.
|
| 718 |
+
int num_added = 0;
|
| 719 |
+
for (int c = 0; c < reduced_cols_; ++c) {
|
| 720 |
+
if ((*reduced_mask)[(r - 1) * reduced_cols_ + c] == 0) {
|
| 721 |
+
(*reduced_mask)[(r - 1) * reduced_cols_ + c] = 1;
|
| 722 |
+
|
| 723 |
+
// Zero out the weights that correspond to this block.
|
| 724 |
+
for (int i = 0; i < block_height_; ++i) {
|
| 725 |
+
for (int j = 0; j < block_width_; ++j) {
|
| 726 |
+
(*weights)[((r - 1) * block_height_ + i) * cols_ +
|
| 727 |
+
block_width_ * c + j] = InputType(0.f);
|
| 728 |
+
}
|
| 729 |
+
}
|
| 730 |
+
num_added++;
|
| 731 |
+
}
|
| 732 |
+
|
| 733 |
+
if (num_added == num_needed) break;
|
| 734 |
+
}
|
| 735 |
+
}
|
| 736 |
+
}
|
| 737 |
+
}
|
| 738 |
+
}
|
| 739 |
+
|
| 740 |
+
// Given the final dense mask and weights, convert to the compressed
|
| 741 |
+
// block CSR representation.
|
| 742 |
+
template <typename InputType>
|
| 743 |
+
void MaskAndWeightsToCsr(const std::vector<int>& mask,
|
| 744 |
+
const std::vector<InputType>& weights,
|
| 745 |
+
std::vector<int>* nnz_per_row,
|
| 746 |
+
std::vector<int>* col_indices,
|
| 747 |
+
std::vector<WeightType>* weights_csr) {
|
| 748 |
+
std::vector<int> row_offsets = {0};
|
| 749 |
+
int nnz = 0;
|
| 750 |
+
// Standard CSR format.
|
| 751 |
+
if (block_width_ == 1 && block_height_ == 1) {
|
| 752 |
+
for (int r = 0; r < rows_; ++r) {
|
| 753 |
+
for (int c = 0; c < cols_; ++c) {
|
| 754 |
+
if (mask[r * cols_ + c] == 1) {
|
| 755 |
+
nnz++;
|
| 756 |
+
col_indices->push_back(c);
|
| 757 |
+
weights_csr->push_back(WeightType(weights[r * cols_ + c]));
|
| 758 |
+
}
|
| 759 |
+
}
|
| 760 |
+
row_offsets.push_back(nnz);
|
| 761 |
+
}
|
| 762 |
+
} else if (block_width_ == 4 && block_height_ == 4) {
|
| 763 |
+
// Weights are stored contiguously for each block in this case.
|
| 764 |
+
for (int r = 0; r < reduced_rows_; ++r) {
|
| 765 |
+
for (int c = 0; c < reduced_cols_; ++c) {
|
| 766 |
+
if (mask[r * reduced_cols_ + c] == 1) {
|
| 767 |
+
col_indices->push_back(c);
|
| 768 |
+
nnz++;
|
| 769 |
+
for (int i = 0; i < block_height_; ++i) {
|
| 770 |
+
for (int j = 0; j < block_width_; ++j) {
|
| 771 |
+
int row_index = (block_height_ * r + i) * cols_;
|
| 772 |
+
int w_index = row_index + block_width_ * c + j;
|
| 773 |
+
WeightType weight = w_index < weights.size()
|
| 774 |
+
? WeightType(weights[w_index])
|
| 775 |
+
: WeightType(0.0f);
|
| 776 |
+
weights_csr->push_back(weight);
|
| 777 |
+
}
|
| 778 |
+
}
|
| 779 |
+
}
|
| 780 |
+
}
|
| 781 |
+
row_offsets.push_back(nnz);
|
| 782 |
+
}
|
| 783 |
+
}
|
| 784 |
+
for (int i = 1; i < row_offsets.size(); ++i)
|
| 785 |
+
nnz_per_row->push_back(row_offsets[i] - row_offsets[i - 1]);
|
| 786 |
+
}
|
| 787 |
+
|
| 788 |
+
// Returns the number of block rows per cache line. This is the minimum unit
|
| 789 |
+
// into which the calculation is broken for threads.
|
| 790 |
+
template <typename OutType>
|
| 791 |
+
int ReducedRowsPerCacheLine(int override_cache_line_size = -1) const {
|
| 792 |
+
int line_size = kCacheLineSize;
|
| 793 |
+
if (override_cache_line_size >= 1) line_size = override_cache_line_size;
|
| 794 |
+
return std::max<int>(line_size / (block_height_ * sizeof(OutType)), 1);
|
| 795 |
+
}
|
| 796 |
+
|
| 797 |
+
int col_multiple_;
|
| 798 |
+
int rows_;
|
| 799 |
+
int cols_;
|
| 800 |
+
int reduced_rows_;
|
| 801 |
+
int reduced_cols_;
|
| 802 |
+
float sparsity_;
|
| 803 |
+
int block_width_;
|
| 804 |
+
int block_height_;
|
| 805 |
+
int num_threads_;
|
| 806 |
+
std::string name_;
|
| 807 |
+
|
| 808 |
+
CacheAlignedVector<WeightType> weights_;
|
| 809 |
+
CacheAlignedVector<DeltaType> col_deltas_;
|
| 810 |
+
CacheAlignedVector<int> nnz_per_row_;
|
| 811 |
+
// |thread_bounds_| and |rhs_indices_| don't need to be serialized as they are
|
| 812 |
+
// always recalculated from serialized data.
|
| 813 |
+
CacheAlignedVector<DeltaType> rhs_indices_;
|
| 814 |
+
Matmul<WeightType, RhsType> matmul_;
|
| 815 |
+
ThreadBounds thread_bounds_;
|
| 816 |
+
static constexpr int kCacheLineSize = 64;
|
| 817 |
+
};
|
| 818 |
+
|
| 819 |
+
// Converts a sparse matrix represented with (|mask|, |weights|, |size|) into
|
| 820 |
+
// the CSR format, and returns that as a serialized string.
|
| 821 |
+
template <typename MaskType>
|
| 822 |
+
std::string ConvertDenseToSparseRepresentation_Int16Deltas(
|
| 823 |
+
const std::vector<MaskType>& mask, const std::vector<float>& weights,
|
| 824 |
+
const int rows, const int cols) {
|
| 825 |
+
MaskedSparseMatrix<float> masked_weights(rows, cols, mask.data(),
|
| 826 |
+
weights.data());
|
| 827 |
+
CsrBlockSparseMatrix<csrblocksparse::bfloat16, float, int16_t>
|
| 828 |
+
sparse_masked_weights(masked_weights);
|
| 829 |
+
std::string buffer;
|
| 830 |
+
sparse_masked_weights.WriteToFlatBuffer(&buffer);
|
| 831 |
+
return buffer;
|
| 832 |
+
}
|
| 833 |
+
|
| 834 |
+
} // namespace csrblocksparse
|
| 835 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_
|
sparse_matmul/layers/csrblocksparse_test.cc
ADDED
|
@@ -0,0 +1,977 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright 2021 Google LLC
|
| 2 |
+
//
|
| 3 |
+
// Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
// you may not use this file except in compliance with the License.
|
| 5 |
+
// You may obtain a copy of the License at
|
| 6 |
+
//
|
| 7 |
+
// http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
//
|
| 9 |
+
// Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
// distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
// See the License for the specific language governing permissions and
|
| 13 |
+
// limitations under the License.
|
| 14 |
+
|
| 15 |
+
#include <array>
|
| 16 |
+
#include <cstdint>
|
| 17 |
+
#include <tuple>
|
| 18 |
+
#include <vector>
|
| 19 |
+
|
| 20 |
+
// Placeholder for get runfiles header.
|
| 21 |
+
#include "absl/status/status.h"
|
| 22 |
+
#include "absl/strings/str_cat.h"
|
| 23 |
+
#include "absl/strings/string_view.h"
|
| 24 |
+
#include "absl/types/span.h"
|
| 25 |
+
#include "gtest/gtest.h"
|
| 26 |
+
#include "include/ghc/filesystem.hpp"
|
| 27 |
+
#include "sparse_matmul/compute/matmul.h"
|
| 28 |
+
#include "sparse_matmul/layers/utils.h"
|
| 29 |
+
#include "sparse_matmul/numerics/test_utils.h"
|
| 30 |
+
#include "sparse_matmul/os/coop_threads.h"
|
| 31 |
+
|
| 32 |
+
namespace csrblocksparse {
|
| 33 |
+
namespace {
|
| 34 |
+
|
| 35 |
+
inline constexpr absl::string_view kTestdataPath = "layers/testdata";
|
| 36 |
+
|
| 37 |
+
TEST(CSRBlockSparseMatrix, FlatBufferSerialization) {
|
| 38 |
+
const int kRows = 8;
|
| 39 |
+
const int kCols = 8;
|
| 40 |
+
std::vector<int> mask = {1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
|
| 41 |
+
1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
|
| 42 |
+
0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1,
|
| 43 |
+
0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1};
|
| 44 |
+
std::vector<float> values(kRows * kCols, 1.f);
|
| 45 |
+
values[1] = 2.f;
|
| 46 |
+
values[3] = 3.f;
|
| 47 |
+
values[36] = -1.f;
|
| 48 |
+
values[45] = -2.f;
|
| 49 |
+
|
| 50 |
+
csrblocksparse::CacheAlignedVector<float> bias(kRows);
|
| 51 |
+
csrblocksparse::CacheAlignedVector<float> rhs(kCols);
|
| 52 |
+
csrblocksparse::CacheAlignedVector<float> out_ref(kRows);
|
| 53 |
+
csrblocksparse::CacheAlignedVector<float> out_test(kRows);
|
| 54 |
+
|
| 55 |
+
bias.FillZero();
|
| 56 |
+
rhs.FillOnes();
|
| 57 |
+
|
| 58 |
+
csrblocksparse::MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(),
|
| 59 |
+
values.data());
|
| 60 |
+
|
| 61 |
+
matrix.SpMM_bias(rhs, bias, &out_ref);
|
| 62 |
+
|
| 63 |
+
csrblocksparse::CsrBlockSparseMatrix<csrblocksparse::bfloat16, float, int16_t>
|
| 64 |
+
block_sparse_matrix(matrix);
|
| 65 |
+
|
| 66 |
+
std::string buffer;
|
| 67 |
+
std::size_t num_bytes = block_sparse_matrix.WriteToFlatBuffer(&buffer);
|
| 68 |
+
|
| 69 |
+
csrblocksparse::CsrBlockSparseMatrix<csrblocksparse::bfloat16, float, int16_t>
|
| 70 |
+
new_block_sparse_matrix(reinterpret_cast<const uint8_t*>(buffer.c_str()),
|
| 71 |
+
num_bytes);
|
| 72 |
+
|
| 73 |
+
new_block_sparse_matrix.SpMM_bias(rhs, bias, &out_test);
|
| 74 |
+
|
| 75 |
+
CheckResult(out_ref, out_test, kCols);
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
template <typename ComputeType, typename RhsType, typename OutType>
|
| 79 |
+
void CorrectnessCheckBlockSpMM(int rows, int cols, int block_height,
|
| 80 |
+
int block_width, float sparsity,
|
| 81 |
+
bool use_relu = false, int num_threads = 1,
|
| 82 |
+
int fatness = 1, bool test_matmul = false) {
|
| 83 |
+
using BiasType = typename TypeOfProduct<ComputeType, RhsType>::type;
|
| 84 |
+
MaskedSparseMatrix<float> matrix(rows, cols, sparsity, block_height,
|
| 85 |
+
block_width);
|
| 86 |
+
matrix.CastWeights<ComputeType>();
|
| 87 |
+
FatCacheAlignedVector<RhsType> rhs(cols, fatness);
|
| 88 |
+
CacheAlignedVector<BiasType> bias(rows);
|
| 89 |
+
FatCacheAlignedVector<OutType> out(rows, fatness);
|
| 90 |
+
|
| 91 |
+
bias.FillRandom();
|
| 92 |
+
rhs.FillRandom();
|
| 93 |
+
out.FillZero();
|
| 94 |
+
FatCacheAlignedVector<OutType> out_reference = out;
|
| 95 |
+
|
| 96 |
+
matrix.SpMM_bias(rhs, bias, &out_reference, use_relu);
|
| 97 |
+
|
| 98 |
+
CsrBlockSparseMatrix<ComputeType, RhsType> sparse_matrix(matrix);
|
| 99 |
+
|
| 100 |
+
SparseLinearLayer<ComputeType, RhsType> sparse_linear_layer(
|
| 101 |
+
std::move(sparse_matrix), std::move(bias));
|
| 102 |
+
num_threads = sparse_linear_layer.PrepareForThreads(num_threads);
|
| 103 |
+
|
| 104 |
+
// Checks that the result of applying each thread's portion serially is
|
| 105 |
+
// correct.
|
| 106 |
+
for (int thread_id = 0; thread_id < num_threads; ++thread_id) {
|
| 107 |
+
sparse_linear_layer.SpMM_bias(rhs, &out, use_relu, thread_id);
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
CheckResult(out_reference, out, sparse_linear_layer.cols());
|
| 111 |
+
|
| 112 |
+
if (test_matmul) {
|
| 113 |
+
for (int thread_id = 0; thread_id < num_threads; ++thread_id) {
|
| 114 |
+
sparse_linear_layer.MatVec(rhs, use_relu, thread_id,
|
| 115 |
+
/*replicas=*/1, /*output_stride=*/0, &out);
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
CheckResult(out_reference, out, sparse_linear_layer.cols());
|
| 119 |
+
}
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
// Does:
|
| 123 |
+
// y = Ax + b;
|
| 124 |
+
// x = Ay + b;
|
| 125 |
+
// y = Ax + b;
|
| 126 |
+
//
|
| 127 |
+
// to make sure that dependent multiplies are correct.
|
| 128 |
+
template <typename ComputeType, typename RhsType, typename OutType>
|
| 129 |
+
void ThreadBody(
|
| 130 |
+
SpinBarrier* spin_barrier, int tid,
|
| 131 |
+
const SparseLinearLayer<ComputeType, RhsType>& sparse_linear_layer,
|
| 132 |
+
FatCacheAlignedVector<RhsType>* rhs, FatCacheAlignedVector<OutType>* out,
|
| 133 |
+
bool use_relu) {
|
| 134 |
+
sparse_linear_layer.SpMM_bias(*rhs, out, use_relu, tid);
|
| 135 |
+
spin_barrier->barrier();
|
| 136 |
+
sparse_linear_layer.SpMM_bias(*out, rhs, use_relu, tid);
|
| 137 |
+
spin_barrier->barrier();
|
| 138 |
+
sparse_linear_layer.SpMM_bias(*rhs, out, use_relu, tid);
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
template <typename ComputeType, typename RhsType, typename OutType>
|
| 142 |
+
void CorrectnessCheckBlockSpMM_MultiThread(int rows, int cols, int block_height,
|
| 143 |
+
int block_width, float sparsity,
|
| 144 |
+
bool use_relu = false,
|
| 145 |
+
int num_threads = 1,
|
| 146 |
+
int fatness = 1) {
|
| 147 |
+
typedef typename TypeOfProduct<ComputeType, RhsType>::type BiasType;
|
| 148 |
+
CHECK(rows == cols);
|
| 149 |
+
MaskedSparseMatrix<float> matrix(rows, cols, sparsity, block_height,
|
| 150 |
+
block_width);
|
| 151 |
+
matrix.CastWeights<ComputeType>();
|
| 152 |
+
FatCacheAlignedVector<RhsType> rhs(cols, fatness);
|
| 153 |
+
FatCacheAlignedVector<RhsType> rhs_mt(cols, fatness);
|
| 154 |
+
CacheAlignedVector<BiasType> bias(rows);
|
| 155 |
+
FatCacheAlignedVector<OutType> out(rows, fatness);
|
| 156 |
+
|
| 157 |
+
bias.FillOnes();
|
| 158 |
+
rhs.FillOnes();
|
| 159 |
+
rhs_mt.FillOnes();
|
| 160 |
+
out.FillZero();
|
| 161 |
+
FatCacheAlignedVector<OutType> out_reference = out;
|
| 162 |
+
|
| 163 |
+
matrix.SpMM_bias(rhs, bias, &out_reference, use_relu);
|
| 164 |
+
matrix.SpMM_bias(out_reference, bias, &rhs, use_relu);
|
| 165 |
+
matrix.SpMM_bias(rhs, bias, &out_reference, use_relu);
|
| 166 |
+
|
| 167 |
+
CsrBlockSparseMatrix<ComputeType, RhsType> sparse_matrix(matrix);
|
| 168 |
+
|
| 169 |
+
num_threads = sparse_matrix.PrepareForThreads(num_threads,
|
| 170 |
+
/*cache_line_size=*/1);
|
| 171 |
+
|
| 172 |
+
SparseLinearLayer<ComputeType, RhsType> sparse_linear_layer(
|
| 173 |
+
std::move(sparse_matrix), std::move(bias));
|
| 174 |
+
|
| 175 |
+
csrblocksparse::LaunchOnThreadsWithBarrier(
|
| 176 |
+
num_threads, ThreadBody<ComputeType, RhsType, OutType>,
|
| 177 |
+
sparse_linear_layer, &rhs_mt, &out, use_relu);
|
| 178 |
+
|
| 179 |
+
CheckResult(out_reference, out, cols);
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
} // namespace
|
| 183 |
+
|
| 184 |
+
TEST(MaskedSparseCorrectness, HandCoded) {
|
| 185 |
+
const int kRows = 8;
|
| 186 |
+
const int kCols = 8;
|
| 187 |
+
// clang-format off
|
| 188 |
+
std::vector<int> mask = {1, 1, 0, 0, 0, 1, 1, 1,
|
| 189 |
+
0, 1, 0, 1, 0, 1, 0, 1,
|
| 190 |
+
1, 0, 0, 1, 1, 1, 1, 0,
|
| 191 |
+
0, 0, 0, 0, 0, 0, 0, 0,
|
| 192 |
+
1, 1, 1, 1, 1, 1, 1, 1,
|
| 193 |
+
0, 0, 0, 0, 1, 1, 0, 0,
|
| 194 |
+
1, 1, 0, 0, 1, 1, 0, 0,
|
| 195 |
+
1, 0, 0, 0, 0, 1, 0, 1};
|
| 196 |
+
// clang-format on
|
| 197 |
+
std::vector<float> values(kRows * kCols, 1.f);
|
| 198 |
+
|
| 199 |
+
std::vector<float> answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f};
|
| 200 |
+
|
| 201 |
+
MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(), values.data());
|
| 202 |
+
CacheAlignedVector<float> rhs(kCols);
|
| 203 |
+
CacheAlignedVector<float> bias(kRows);
|
| 204 |
+
CacheAlignedVector<float> out(kRows);
|
| 205 |
+
|
| 206 |
+
bias.FillOnes();
|
| 207 |
+
rhs.FillOnes();
|
| 208 |
+
out.FillZero();
|
| 209 |
+
|
| 210 |
+
MaskedLinearLayer<float> masked_linear_layer(std::move(matrix),
|
| 211 |
+
std::move(bias));
|
| 212 |
+
|
| 213 |
+
masked_linear_layer.SpMM_bias(rhs, &out);
|
| 214 |
+
|
| 215 |
+
for (int i = 0; i < kRows; ++i) {
|
| 216 |
+
EXPECT_EQ(answer[i], out[i]);
|
| 217 |
+
}
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
TEST(MaskedSparseCorrectness, HandCodedFatVector) {
|
| 221 |
+
const int kRows = 8;
|
| 222 |
+
const int kCols = 8;
|
| 223 |
+
// clang-format off
|
| 224 |
+
std::vector<int> mask = {1, 1, 0, 0, 0, 1, 1, 1,
|
| 225 |
+
0, 1, 0, 1, 0, 1, 0, 1,
|
| 226 |
+
1, 0, 0, 1, 1, 1, 1, 0,
|
| 227 |
+
0, 0, 0, 0, 0, 0, 0, 0,
|
| 228 |
+
1, 1, 1, 1, 1, 1, 1, 1,
|
| 229 |
+
0, 0, 0, 0, 1, 1, 0, 0,
|
| 230 |
+
1, 1, 0, 0, 1, 1, 0, 0,
|
| 231 |
+
1, 0, 0, 0, 0, 1, 0, 1};
|
| 232 |
+
// clang-format on
|
| 233 |
+
|
| 234 |
+
std::vector<float> values(kRows * kCols, 1.f);
|
| 235 |
+
std::vector<float> answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f};
|
| 236 |
+
|
| 237 |
+
MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(), values.data());
|
| 238 |
+
const int kMaxWidth = 5;
|
| 239 |
+
for (int width = 5; width <= kMaxWidth; ++width) {
|
| 240 |
+
FatCacheAlignedVector<float> rhs(kCols, width);
|
| 241 |
+
CacheAlignedVector<float> bias(kRows);
|
| 242 |
+
FatCacheAlignedVector<float> out(kRows, width);
|
| 243 |
+
|
| 244 |
+
bias.FillOnes();
|
| 245 |
+
rhs.FillOnes();
|
| 246 |
+
out.FillZero();
|
| 247 |
+
|
| 248 |
+
MaskedLinearLayer<float> masked_linear_layer(std::move(matrix),
|
| 249 |
+
std::move(bias));
|
| 250 |
+
|
| 251 |
+
masked_linear_layer.SpMM_bias(rhs, &out);
|
| 252 |
+
|
| 253 |
+
for (int i = 0; i < kRows; ++i) {
|
| 254 |
+
for (int width = 0; width < kMaxWidth; ++width) {
|
| 255 |
+
EXPECT_EQ(answer[i], out[i + width * kRows]);
|
| 256 |
+
}
|
| 257 |
+
}
|
| 258 |
+
}
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
TEST(CsrBlockSparseMatrix, HandCodedMultiThread) {
|
| 262 |
+
const int kRows = 8;
|
| 263 |
+
const int kCols = 8;
|
| 264 |
+
// clang-format off
|
| 265 |
+
std::vector<int> mask = {1, 1, 0, 0, 0, 1, 1, 1,
|
| 266 |
+
0, 1, 0, 1, 0, 1, 0, 1,
|
| 267 |
+
1, 0, 0, 1, 1, 1, 1, 0,
|
| 268 |
+
0, 0, 0, 0, 0, 0, 0, 0,
|
| 269 |
+
1, 1, 1, 1, 1, 1, 1, 1,
|
| 270 |
+
0, 0, 0, 0, 1, 1, 0, 0,
|
| 271 |
+
1, 1, 0, 0, 1, 1, 0, 0,
|
| 272 |
+
1, 0, 0, 0, 0, 1, 0, 1};
|
| 273 |
+
// clang-format on
|
| 274 |
+
std::vector<float> values(kRows * kCols, 1.f);
|
| 275 |
+
|
| 276 |
+
std::vector<float> answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f};
|
| 277 |
+
|
| 278 |
+
MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(), values.data());
|
| 279 |
+
CacheAlignedVector<float> rhs(kCols);
|
| 280 |
+
CacheAlignedVector<float> bias(kRows);
|
| 281 |
+
CacheAlignedVector<float> out(kRows);
|
| 282 |
+
|
| 283 |
+
bias.FillOnes();
|
| 284 |
+
rhs.FillOnes();
|
| 285 |
+
out.FillZero();
|
| 286 |
+
|
| 287 |
+
CacheAlignedVector<float> bias_csr = bias;
|
| 288 |
+
|
| 289 |
+
CsrBlockSparseMatrix<bfloat16, float> sparse_matrix(matrix);
|
| 290 |
+
|
| 291 |
+
MaskedLinearLayer<float> masked_linear_layer(std::move(matrix),
|
| 292 |
+
std::move(bias));
|
| 293 |
+
|
| 294 |
+
masked_linear_layer.SpMM_bias(rhs, &out);
|
| 295 |
+
|
| 296 |
+
SparseLinearLayer<bfloat16, float> sparse_linear_layer(
|
| 297 |
+
std::move(sparse_matrix), std::move(bias_csr));
|
| 298 |
+
sparse_linear_layer.PrepareForThreads(2, /*cache_line_size=*/1);
|
| 299 |
+
|
| 300 |
+
CacheAlignedVector<float> out_tmp(kRows);
|
| 301 |
+
const bool kUseRelu = false;
|
| 302 |
+
sparse_linear_layer.SpMM_bias(rhs, &out_tmp, kUseRelu, /*tid=*/0);
|
| 303 |
+
sparse_linear_layer.SpMM_bias(rhs, &out_tmp, kUseRelu, /*tid=*/1);
|
| 304 |
+
|
| 305 |
+
for (int i = 0; i < kRows; ++i) {
|
| 306 |
+
EXPECT_EQ(answer[i], out_tmp[i]);
|
| 307 |
+
}
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
TEST(TestCasts, TestBfloat16) {
|
| 311 |
+
const int kRows = 1000;
|
| 312 |
+
const int kCols = 100;
|
| 313 |
+
const float kSparsity = 0.f;
|
| 314 |
+
|
| 315 |
+
MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);
|
| 316 |
+
MaskedSparseMatrix<float> matrix_bfloat16(kRows, kCols, matrix.mask().data(),
|
| 317 |
+
matrix.values().data());
|
| 318 |
+
|
| 319 |
+
matrix_bfloat16.CastWeights<bfloat16>();
|
| 320 |
+
|
| 321 |
+
CheckResult(matrix.values(), matrix_bfloat16.values(), kCols);
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
TEST(TestCasts, TestFP16) {
|
| 325 |
+
const int kRows = 1000;
|
| 326 |
+
const int kCols = 100;
|
| 327 |
+
const float kSparsity = 0.f;
|
| 328 |
+
|
| 329 |
+
MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);
|
| 330 |
+
#if !defined __arm__ && !defined __aarch64__
|
| 331 |
+
// Conversion doesn't handle denormals, so flush denormals to zero first.
|
| 332 |
+
for (int i = 0; i < matrix.values().size(); ++i) {
|
| 333 |
+
if (matrix.data()[i] < 1. / static_cast<float>(1 << 14))
|
| 334 |
+
matrix.data()[i] = 0.f;
|
| 335 |
+
}
|
| 336 |
+
#endif
|
| 337 |
+
MaskedSparseMatrix<float> matrix_fp16(kRows, kCols, matrix.mask().data(),
|
| 338 |
+
matrix.values().data());
|
| 339 |
+
|
| 340 |
+
matrix_fp16.CastWeights<csrblocksparse::fp16>();
|
| 341 |
+
|
| 342 |
+
CheckResult(matrix.values(), matrix_fp16.values(), kCols);
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
TEST(TestCasts, TestFixed16) {
|
| 346 |
+
const int kRows = 100000;
|
| 347 |
+
const int kCols = 1;
|
| 348 |
+
const float kSparsity = 0.f;
|
| 349 |
+
|
| 350 |
+
MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);
|
| 351 |
+
|
| 352 |
+
// Relative error for fixed point is high near 0.
|
| 353 |
+
for (int i = 0; i < matrix.values().size(); ++i) {
|
| 354 |
+
// 1.1e-3 is based on the max error of .013 and a grid spacing of 1 / 2**16
|
| 355 |
+
// == 3e-5. 3e-5 / .013 / 2 = 1.1e-3.
|
| 356 |
+
if (std::abs(matrix.data()[i]) < 1.1e-3) {
|
| 357 |
+
matrix.data()[i] = 0.f;
|
| 358 |
+
}
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
MaskedSparseMatrix<float> matrix_fixed16 = matrix;
|
| 362 |
+
|
| 363 |
+
matrix_fixed16.CastWeights<csrblocksparse::fixed16</*ExponentBits=*/0>>();
|
| 364 |
+
|
| 365 |
+
CheckResult(matrix.values(), matrix_fixed16.values(), kCols);
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
TEST(TestCasts, TestFixed32) {
|
| 369 |
+
const int kRows = 100000;
|
| 370 |
+
const int kCols = 1;
|
| 371 |
+
const float kSparsity = 0.f;
|
| 372 |
+
|
| 373 |
+
MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);
|
| 374 |
+
MaskedSparseMatrix<float> matrix_fixed32 = matrix;
|
| 375 |
+
|
| 376 |
+
matrix_fixed32.CastWeights<csrblocksparse::fixed32</*ExponentBits=*/0>>();
|
| 377 |
+
|
| 378 |
+
CheckResult(matrix.values(), matrix_fixed32.values(), kCols);
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
template <typename ComputeType, typename RhsType, typename OutType>
|
| 382 |
+
void TestSpMM(int block_width, int block_height, int fatness,
|
| 383 |
+
bool test_matmul = false) {
|
| 384 |
+
std::array<bool, 2> use_relu = {false, true};
|
| 385 |
+
std::vector<float> sparsity_levels = {.5, .8, .9, .95, .98};
|
| 386 |
+
std::vector<std::pair<int, int>> sizes = {{8, 8}, {128, 128}, {128, 64},
|
| 387 |
+
{256, 192}, {512, 512}, {1024, 512},
|
| 388 |
+
{384, 384}, {512, 384}};
|
| 389 |
+
for (int num_threads = 1; num_threads < 2 + test_matmul; ++num_threads) {
|
| 390 |
+
for (const auto& relu : use_relu) {
|
| 391 |
+
for (const auto& sparsity : sparsity_levels) {
|
| 392 |
+
for (const auto& size : sizes) {
|
| 393 |
+
int rows, cols;
|
| 394 |
+
std::tie(rows, cols) = size;
|
| 395 |
+
CorrectnessCheckBlockSpMM<ComputeType, RhsType, OutType>(
|
| 396 |
+
rows, cols, block_height, block_width, sparsity, relu,
|
| 397 |
+
num_threads, fatness, test_matmul);
|
| 398 |
+
}
|
| 399 |
+
}
|
| 400 |
+
}
|
| 401 |
+
}
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
template <typename ComputeType, typename RhsType, typename OutType>
|
| 405 |
+
void TestSpMM_MultiThread(int block_width, int block_height, int fatness) {
|
| 406 |
+
std::array<bool, 2> use_relu = {false, true};
|
| 407 |
+
std::vector<float> sparsity_levels = {.5, .8, .9, .95, .98};
|
| 408 |
+
std::vector<std::pair<int, int>> sizes = {
|
| 409 |
+
{48, 48}, {128, 128}, {512, 512}, {384, 384}};
|
| 410 |
+
for (int num_threads = 1; num_threads < 5; ++num_threads) {
|
| 411 |
+
for (const auto& relu : use_relu) {
|
| 412 |
+
for (const auto& sparsity : sparsity_levels) {
|
| 413 |
+
for (const auto& size : sizes) {
|
| 414 |
+
int rows, cols;
|
| 415 |
+
std::tie(rows, cols) = size;
|
| 416 |
+
CorrectnessCheckBlockSpMM_MultiThread<ComputeType, RhsType, OutType>(
|
| 417 |
+
rows, cols, block_height, block_width, sparsity, relu,
|
| 418 |
+
num_threads, fatness);
|
| 419 |
+
}
|
| 420 |
+
}
|
| 421 |
+
}
|
| 422 |
+
}
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
template <typename DataType>
|
| 426 |
+
void TestSumVectors(int start = 0, int end = -1, int size = 6) {
|
| 427 |
+
std::vector<DataType> values;
|
| 428 |
+
std::vector<DataType> answer;
|
| 429 |
+
|
| 430 |
+
for (int i = 1; i < size + 1; ++i) {
|
| 431 |
+
const float x = static_cast<float>(i);
|
| 432 |
+
values.push_back(static_cast<DataType>(x));
|
| 433 |
+
answer.push_back(static_cast<DataType>(x * 2));
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
if (end == -1) {
|
| 437 |
+
end = values.size();
|
| 438 |
+
}
|
| 439 |
+
|
| 440 |
+
csrblocksparse::CacheAlignedVector<DataType> result(values.size());
|
| 441 |
+
csrblocksparse::CacheAlignedVector<DataType> values_aligned(values);
|
| 442 |
+
detail::SumVectors(start, end, values_aligned.data(), values_aligned.data(),
|
| 443 |
+
result.data());
|
| 444 |
+
for (int i = start; i < end; ++i) {
|
| 445 |
+
EXPECT_EQ(static_cast<float>(answer[i]), static_cast<float>(result[i]));
|
| 446 |
+
}
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
TEST(CsrBlockSparseMatrix, SumVectors_Generic) {
|
| 450 |
+
TestSumVectors<float>();
|
| 451 |
+
TestSumVectors<float>(1);
|
| 452 |
+
TestSumVectors<float>(1, 4);
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
TEST(CsrBlockSparseMatrix, SumVectors_Bfloat16) {
|
| 456 |
+
TestSumVectors<csrblocksparse::bfloat16>();
|
| 457 |
+
TestSumVectors<csrblocksparse::bfloat16>(1);
|
| 458 |
+
TestSumVectors<csrblocksparse::bfloat16>(1, 4);
|
| 459 |
+
}
|
| 460 |
+
|
| 461 |
+
// For SIMD-optimized SumVectors, the memory of the vector should be at least
|
| 462 |
+
// |kSIMDWidth * sizeof(float)| long, and the start position has to be an
|
| 463 |
+
// aligned memory location. So setting |size| to be 100 to be safe and
|
| 464 |
+
// |start| to be 0 (|start| == 1 is not aligned).
|
| 465 |
+
TEST(CsrBlockSparseMatrix, SumVectors_Fixed16) {
|
| 466 |
+
TestSumVectors<csrblocksparse::fixed16<8>>(0, -1, 100);
|
| 467 |
+
TestSumVectors<csrblocksparse::fixed16<8>>(0, 4, 100);
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
TEST(CsrBlockSparseMatrix, SumVectors_Fixed32) {
|
| 471 |
+
TestSumVectors<csrblocksparse::fixed32<11>>(0, -1, 100);
|
| 472 |
+
TestSumVectors<csrblocksparse::fixed32<11>>(0, 4, 100);
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block4x4_Bfloat16) {
|
| 476 |
+
TestSpMM<csrblocksparse::bfloat16, float, float>(/*block_width=*/4,
|
| 477 |
+
/*block_height=*/4,
|
| 478 |
+
/*fatness=*/7);
|
| 479 |
+
}
|
| 480 |
+
|
| 481 |
+
// This actually uses multiple threads, and uses the output as the input for
|
| 482 |
+
// multiple steps to test that synchronization and memory visibility is
|
| 483 |
+
// working correctly.Requires square matrices.
|
| 484 |
+
TEST(CsrBlockSparseMatrix, SpMV_4x4MultiThreading_Bfloat16) {
|
| 485 |
+
TestSpMM_MultiThread<csrblocksparse::bfloat16, float, float>(
|
| 486 |
+
/*block_width=*/4,
|
| 487 |
+
/*block_height=*/4,
|
| 488 |
+
/*fatness=*/1);
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
+
TEST(CsrBlockSparseMatrix, SpMM_4x4MultiThreading_Bfloat16) {
|
| 492 |
+
TestSpMM_MultiThread<csrblocksparse::bfloat16, float, float>(
|
| 493 |
+
/*block_width=*/4,
|
| 494 |
+
/*block_height=*/4,
|
| 495 |
+
/*fatness=*/7);
|
| 496 |
+
}
|
| 497 |
+
|
| 498 |
+
TEST(CsrBlockSparseMatrix, SpMV_Block1x1_Bfloat16) {
|
| 499 |
+
TestSpMM<csrblocksparse::bfloat16, float, float>(/*block_width=*/1,
|
| 500 |
+
/*block_height=*/1,
|
| 501 |
+
/*fatness=*/1);
|
| 502 |
+
}
|
| 503 |
+
|
| 504 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block1x1_Bfloat16) {
|
| 505 |
+
TestSpMM<csrblocksparse::bfloat16, float, float>(/*block_width=*/1,
|
| 506 |
+
/*block_height=*/1,
|
| 507 |
+
/*fatness=*/7);
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
// This actually uses multiple threads, and uses the output as the input for
|
| 511 |
+
// multiple steps to test that synchronization and memory visibility is
|
| 512 |
+
// working correctly.Requires square matrices.
|
| 513 |
+
TEST(CsrBlockSparseMatrix, SpMV_1x1MultiThreading_Bfloat16) {
|
| 514 |
+
TestSpMM_MultiThread<csrblocksparse::bfloat16, float, float>(
|
| 515 |
+
/*block_width=*/1,
|
| 516 |
+
/*block_height=*/1,
|
| 517 |
+
/*fatness=*/1);
|
| 518 |
+
}
|
| 519 |
+
|
| 520 |
+
TEST(CsrBlockSparseMatrix, SpMM_1x1MultiThreading_Bfloat16) {
|
| 521 |
+
TestSpMM_MultiThread<csrblocksparse::bfloat16, float, float>(
|
| 522 |
+
/*block_width=*/1,
|
| 523 |
+
/*block_height=*/1,
|
| 524 |
+
/*fatness=*/7);
|
| 525 |
+
}
|
| 526 |
+
|
| 527 |
+
TEST(CsrBlockSparseMatrix, SpMV_Block4x4_float) {
|
| 528 |
+
TestSpMM<float, float, float>(/*block_width=*/4,
|
| 529 |
+
/*block_height=*/4,
|
| 530 |
+
/*fatness=*/1,
|
| 531 |
+
/*test_matmul=*/true);
|
| 532 |
+
}
|
| 533 |
+
|
| 534 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block4x4_float) {
|
| 535 |
+
TestSpMM<float, float, float>(/*block_width=*/4,
|
| 536 |
+
/*block_height=*/4,
|
| 537 |
+
/*fatness=*/7);
|
| 538 |
+
}
|
| 539 |
+
|
| 540 |
+
// This actually uses multiple threads, and uses the output as the input for
|
| 541 |
+
// multiple steps to test that synchronization and memory visibility is
|
| 542 |
+
// working correctly.Requires square matrices.
|
| 543 |
+
TEST(CsrBlockSparseMatrix, SpMV_4x4MultiThreading_float) {
|
| 544 |
+
TestSpMM_MultiThread<float, float, float>(/*block_width=*/4,
|
| 545 |
+
/*block_height=*/4,
|
| 546 |
+
/*fatness=*/1);
|
| 547 |
+
}
|
| 548 |
+
|
| 549 |
+
TEST(CsrBlockSparseMatrix, SpMM_4x4MultiThreading_float) {
|
| 550 |
+
TestSpMM_MultiThread<float, float, float>(/*block_width=*/4,
|
| 551 |
+
/*block_height=*/4,
|
| 552 |
+
/*fatness=*/7);
|
| 553 |
+
}
|
| 554 |
+
|
| 555 |
+
TEST(CsrBlockSparseMatrix, SpMV_Block1x1_float) {
|
| 556 |
+
TestSpMM<float, float, float>(/*block_width=*/1,
|
| 557 |
+
/*block_height=*/1,
|
| 558 |
+
/*fatness=*/1);
|
| 559 |
+
}
|
| 560 |
+
|
| 561 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block1x1_float) {
|
| 562 |
+
TestSpMM<float, float, float>(/*block_width=*/1,
|
| 563 |
+
/*block_height=*/1,
|
| 564 |
+
/*fatness=*/7);
|
| 565 |
+
}
|
| 566 |
+
|
| 567 |
+
// This actually uses multiple threads, and uses the output as the input for
|
| 568 |
+
// multiple steps to test that synchronization and memory visibility is
|
| 569 |
+
// working correctly.Requires square matrices.
|
| 570 |
+
TEST(CsrBlockSparseMatrix, SpMV_1x1MultiThreading_float) {
|
| 571 |
+
TestSpMM_MultiThread<float, float, float>(/*block_width=*/1,
|
| 572 |
+
/*block_height=*/1,
|
| 573 |
+
/*fatness=*/1);
|
| 574 |
+
}
|
| 575 |
+
|
| 576 |
+
TEST(CsrBlockSparseMatrix, SpMM_1x1MultiThreading_float) {
|
| 577 |
+
TestSpMM_MultiThread<float, float, float>(/*block_width=*/1,
|
| 578 |
+
/*block_height=*/1,
|
| 579 |
+
/*fatness=*/7);
|
| 580 |
+
}
|
| 581 |
+
|
| 582 |
+
TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_32) {
|
| 583 |
+
TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
|
| 584 |
+
typename csrblocksparse::TypeOfProduct<
|
| 585 |
+
csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
|
| 586 |
+
/*block_width=*/4,
|
| 587 |
+
/*block_height=*/4,
|
| 588 |
+
/*fatness=*/1,
|
| 589 |
+
/*test_matmul=*/true);
|
| 590 |
+
}
|
| 591 |
+
|
| 592 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_32) {
|
| 593 |
+
TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
|
| 594 |
+
typename csrblocksparse::TypeOfProduct<
|
| 595 |
+
csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
|
| 596 |
+
/*block_width=*/4,
|
| 597 |
+
/*block_height=*/4,
|
| 598 |
+
/*fatness=*/7);
|
| 599 |
+
}
|
| 600 |
+
|
| 601 |
+
TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_32) {
|
| 602 |
+
TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
|
| 603 |
+
typename csrblocksparse::TypeOfProduct<
|
| 604 |
+
csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
|
| 605 |
+
/*block_width=*/1,
|
| 606 |
+
/*block_height=*/1,
|
| 607 |
+
/*fatness=*/1);
|
| 608 |
+
}
|
| 609 |
+
|
| 610 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_32) {
|
| 611 |
+
TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
|
| 612 |
+
typename csrblocksparse::TypeOfProduct<
|
| 613 |
+
csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
|
| 614 |
+
/*block_width=*/1,
|
| 615 |
+
/*block_height=*/1,
|
| 616 |
+
/*fatness=*/7);
|
| 617 |
+
}
|
| 618 |
+
|
| 619 |
+
TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_16) {
|
| 620 |
+
TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
|
| 621 |
+
csrblocksparse::fixed16<8>>(
|
| 622 |
+
/*block_width=*/4,
|
| 623 |
+
/*block_height=*/4,
|
| 624 |
+
/*fatness=*/1,
|
| 625 |
+
/*test_matmul=*/true);
|
| 626 |
+
}
|
| 627 |
+
|
| 628 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_16) {
|
| 629 |
+
TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
|
| 630 |
+
csrblocksparse::fixed16<8>>(
|
| 631 |
+
/*block_width=*/4,
|
| 632 |
+
/*block_height=*/4,
|
| 633 |
+
/*fatness=*/7);
|
| 634 |
+
}
|
| 635 |
+
|
| 636 |
+
TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_16) {
|
| 637 |
+
TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
|
| 638 |
+
csrblocksparse::fixed16<8>>(
|
| 639 |
+
/*block_width=*/1,
|
| 640 |
+
/*block_height=*/1,
|
| 641 |
+
/*fatness=*/1);
|
| 642 |
+
}
|
| 643 |
+
|
| 644 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_16) {
|
| 645 |
+
TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
|
| 646 |
+
csrblocksparse::fixed16<8>>(
|
| 647 |
+
/*block_width=*/1,
|
| 648 |
+
/*block_height=*/1,
|
| 649 |
+
/*fatness=*/7);
|
| 650 |
+
}
|
| 651 |
+
|
| 652 |
+
TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_32_unmatched) {
|
| 653 |
+
TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
|
| 654 |
+
csrblocksparse::fixed32<13>>(
|
| 655 |
+
/*block_width=*/4,
|
| 656 |
+
/*block_height=*/4,
|
| 657 |
+
/*fatness=*/1,
|
| 658 |
+
/*test_matmul=*/true);
|
| 659 |
+
}
|
| 660 |
+
|
| 661 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_32_unmatched) {
|
| 662 |
+
TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
|
| 663 |
+
csrblocksparse::fixed32<13>>(
|
| 664 |
+
/*block_width=*/4,
|
| 665 |
+
/*block_height=*/4,
|
| 666 |
+
/*fatness=*/7);
|
| 667 |
+
}
|
| 668 |
+
|
| 669 |
+
TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_32_unmatched) {
|
| 670 |
+
TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
|
| 671 |
+
csrblocksparse::fixed32<13>>(
|
| 672 |
+
/*block_width=*/1,
|
| 673 |
+
/*block_height=*/1,
|
| 674 |
+
/*fatness=*/1);
|
| 675 |
+
}
|
| 676 |
+
|
| 677 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_32_unmatched) {
|
| 678 |
+
TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
|
| 679 |
+
csrblocksparse::fixed32<13>>(
|
| 680 |
+
/*block_width=*/1,
|
| 681 |
+
/*block_height=*/1,
|
| 682 |
+
/*fatness=*/7);
|
| 683 |
+
}
|
| 684 |
+
|
| 685 |
+
TEST(CsrBlockSparseMatrix, RhsIndicesDeltasRoundTrip) {
|
| 686 |
+
MaskedSparseMatrix<float> matrix(/*rows=*/256, /*cols=*/256,
|
| 687 |
+
/*sparsity=*/0.9, /*block_height=*/4,
|
| 688 |
+
/*block_width=*/4);
|
| 689 |
+
CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);
|
| 690 |
+
CacheAlignedVector<int16_t> copy_indices = sparse_matrix.rhs_indices();
|
| 691 |
+
sparse_matrix.ComputeColDeltas();
|
| 692 |
+
sparse_matrix.ComputeRHSIndices();
|
| 693 |
+
// They get padded when created, so the newer one could be bigger.
|
| 694 |
+
EXPECT_LE(copy_indices.size(), sparse_matrix.rhs_indices().size());
|
| 695 |
+
for (int i = 0; i < copy_indices.size(); ++i) {
|
| 696 |
+
EXPECT_EQ(copy_indices[i], sparse_matrix.rhs_indices()[i]) << "i=" << i;
|
| 697 |
+
}
|
| 698 |
+
}
|
| 699 |
+
|
| 700 |
+
// Tests that a Layer that is split into 2 by columns (inputs) computes the same
|
| 701 |
+
// result as the original layer.
|
| 702 |
+
TEST(CsrBlockSparseMatrix, SplitByCol) {
|
| 703 |
+
int kRows = 1024;
|
| 704 |
+
int kCols = 1024;
|
| 705 |
+
MaskedSparseMatrix<float> matrix(kRows, kCols, 0.95, /*block_height=*/4,
|
| 706 |
+
/*block_width=*/4);
|
| 707 |
+
FatCacheAlignedVector<float> rhs(kCols, /*cols=*/1);
|
| 708 |
+
CacheAlignedVector<float> bias(kRows);
|
| 709 |
+
FatCacheAlignedVector<float> out1(kRows, /*cols=*/1);
|
| 710 |
+
FatCacheAlignedVector<float> out2(kRows, /*cols=*/1);
|
| 711 |
+
|
| 712 |
+
bias.FillRandom();
|
| 713 |
+
rhs.FillRandom();
|
| 714 |
+
out1.FillZero();
|
| 715 |
+
out2.FillZero();
|
| 716 |
+
FatCacheAlignedVector<float> out_reference = out1;
|
| 717 |
+
|
| 718 |
+
CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);
|
| 719 |
+
|
| 720 |
+
SparseLinearLayer<float, float> sparse_linear_layer(std::move(sparse_matrix),
|
| 721 |
+
std::move(bias));
|
| 722 |
+
sparse_linear_layer.PrepareForThreads(1);
|
| 723 |
+
sparse_linear_layer.SpMM_bias(rhs, &out_reference, /*relu=*/false,
|
| 724 |
+
/*tid=*/0);
|
| 725 |
+
// Split the layer into 2 parts.
|
| 726 |
+
SparseLinearLayer<float, float> part1, part2;
|
| 727 |
+
sparse_linear_layer.SplitInputs(&part1, &part2);
|
| 728 |
+
part1.PrepareForThreads(1);
|
| 729 |
+
part2.PrepareForThreads(1);
|
| 730 |
+
EXPECT_EQ(kRows, part1.rows());
|
| 731 |
+
EXPECT_EQ(kCols / 2, part1.cols());
|
| 732 |
+
EXPECT_EQ(kRows, part2.rows());
|
| 733 |
+
EXPECT_EQ(kCols / 2, part2.cols());
|
| 734 |
+
MutableVectorView<float> rhs1(&rhs, 0, kCols / 2);
|
| 735 |
+
MutableVectorView<float> rhs2(&rhs, kCols / 2, kCols / 2);
|
| 736 |
+
for (int i = 0; i < kCols / 2; ++i) {
|
| 737 |
+
EXPECT_FLOAT_EQ(rhs[i], rhs1.data()[i]);
|
| 738 |
+
EXPECT_FLOAT_EQ(rhs[i + kCols / 2], rhs2.data()[i]);
|
| 739 |
+
}
|
| 740 |
+
part1.SpMM_bias(rhs1, &out1, /*relu=*/false, /*tid=*/0);
|
| 741 |
+
part2.SpMM_bias(rhs2, &out2, /*relu=*/false, /*tid=*/0);
|
| 742 |
+
// Check that out1 + out2 = out_reference.
|
| 743 |
+
for (int i = 0; i < kRows; ++i) {
|
| 744 |
+
EXPECT_NEAR(out_reference[i], out1[i] + out2[i], 2e-5)
|
| 745 |
+
<< " i=" << i << " out1=" << out1[i] << " out2=" << out2[i];
|
| 746 |
+
}
|
| 747 |
+
}
|
| 748 |
+
// Tests that a Layer that is split into 2 by rows (outputs) computes the same
|
| 749 |
+
// result as the original layer.
|
| 750 |
+
TEST(CsrBlockSparseMatrix, SplitByRow) {
|
| 751 |
+
int kRows = 1024;
|
| 752 |
+
int kCols = 1024;
|
| 753 |
+
MaskedSparseMatrix<float> matrix(kRows, kCols, 0.95, /*block_height=*/4,
|
| 754 |
+
/*block_width=*/4);
|
| 755 |
+
FatCacheAlignedVector<float> rhs(kCols, /*cols=*/1);
|
| 756 |
+
CacheAlignedVector<float> bias(kRows);
|
| 757 |
+
FatCacheAlignedVector<float> out1(kRows, /*cols=*/1);
|
| 758 |
+
FatCacheAlignedVector<float> out2(kRows, /*cols=*/1);
|
| 759 |
+
|
| 760 |
+
bias.FillRandom();
|
| 761 |
+
rhs.FillRandom();
|
| 762 |
+
out1.FillZero();
|
| 763 |
+
out2.FillZero();
|
| 764 |
+
FatCacheAlignedVector<float> out_reference = out1;
|
| 765 |
+
|
| 766 |
+
CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);
|
| 767 |
+
|
| 768 |
+
SparseLinearLayer<float, float> sparse_linear_layer(std::move(sparse_matrix),
|
| 769 |
+
std::move(bias));
|
| 770 |
+
sparse_linear_layer.PrepareForThreads(1);
|
| 771 |
+
sparse_linear_layer.SpMM_bias(rhs, &out_reference, /*relu=*/false,
|
| 772 |
+
/*tid=*/0);
|
| 773 |
+
// Split the layer into 2 parts.
|
| 774 |
+
SparseLinearLayer<float, float> part1, part2;
|
| 775 |
+
sparse_linear_layer.SplitOutputs(&part1, &part2);
|
| 776 |
+
part1.PrepareForThreads(1);
|
| 777 |
+
part2.PrepareForThreads(1);
|
| 778 |
+
EXPECT_EQ(kRows / 2, part1.rows());
|
| 779 |
+
EXPECT_EQ(kCols, part1.cols());
|
| 780 |
+
EXPECT_EQ(kRows / 2, part2.rows());
|
| 781 |
+
EXPECT_EQ(kCols, part2.cols());
|
| 782 |
+
MutableVectorView<float> out2a(&out2, 0, kRows / 2);
|
| 783 |
+
MutableVectorView<float> out2b(&out2, kRows / 2, kRows / 2);
|
| 784 |
+
part1.SpMM_bias(rhs, &out2a, /*relu=*/false, /*tid=*/0);
|
| 785 |
+
part2.SpMM_bias(rhs, &out2b, /*relu=*/false, /*tid=*/0);
|
| 786 |
+
// Check that out2 = out_reference.
|
| 787 |
+
for (int i = 0; i < kRows; ++i) {
|
| 788 |
+
EXPECT_NEAR(out_reference[i], out2[i], 2e-5)
|
| 789 |
+
<< " i=" << i << " out1=" << out_reference[i] << " out2=" << out2[i];
|
| 790 |
+
}
|
| 791 |
+
}
|
| 792 |
+
|
| 793 |
+
TEST(CsrBlockSparseMatrix, MutableVectorView) {
|
| 794 |
+
const int kRows = 1024;
|
| 795 |
+
const int kCols = 1024;
|
| 796 |
+
const int kFatness = 2;
|
| 797 |
+
|
| 798 |
+
std::vector<float> values(kRows * kCols, 1.f);
|
| 799 |
+
std::vector<int> mask(kRows * kCols);
|
| 800 |
+
for (int i = 0; i < mask.size(); ++i) mask[i] = i % 2;
|
| 801 |
+
|
| 802 |
+
auto masked_matrix =
|
| 803 |
+
MaskedSparseMatrix<float>(kRows, kCols, mask.data(), values.data());
|
| 804 |
+
auto sparse_matrix = CsrBlockSparseMatrix<bfloat16, float>(masked_matrix);
|
| 805 |
+
FatCacheAlignedVector<float> x(kCols, kFatness);
|
| 806 |
+
x.FillOnes();
|
| 807 |
+
|
| 808 |
+
CacheAlignedVector<float> bias(kRows);
|
| 809 |
+
bias.FillZero();
|
| 810 |
+
|
| 811 |
+
// First check that we can use spans as output. Split a multiplication
|
| 812 |
+
// into upper and lower halves times the full vector:
|
| 813 |
+
// --------------- x t
|
| 814 |
+
// | | x t
|
| 815 |
+
// | | x t
|
| 816 |
+
// --------------- =
|
| 817 |
+
// | | x b
|
| 818 |
+
// | | x b
|
| 819 |
+
// --------------- x b
|
| 820 |
+
|
| 821 |
+
FatCacheAlignedVector<float> out(kRows, kFatness);
|
| 822 |
+
FatCacheAlignedVector<float> out_view(kRows, kFatness);
|
| 823 |
+
|
| 824 |
+
MutableVectorView<float> out_view_top(&out_view, 0, kRows / 2);
|
| 825 |
+
MutableVectorView<float> out_view_bottom(&out_view, kRows / 2, kRows / 2);
|
| 826 |
+
|
| 827 |
+
sparse_matrix.SpMM_bias(x, bias, &out);
|
| 828 |
+
|
| 829 |
+
auto masked_matrix_top =
|
| 830 |
+
MaskedSparseMatrix<float>(kRows / 2, kCols, mask.data(), values.data());
|
| 831 |
+
auto masked_matrix_bottom = MaskedSparseMatrix<float>(
|
| 832 |
+
kRows / 2, kCols, mask.data() + kRows * kCols / 2,
|
| 833 |
+
values.data() + kRows * kCols / 2);
|
| 834 |
+
auto sparse_matrix_top =
|
| 835 |
+
CsrBlockSparseMatrix<bfloat16, float>(masked_matrix_top);
|
| 836 |
+
auto sparse_matrix_bottom =
|
| 837 |
+
CsrBlockSparseMatrix<bfloat16, float>(masked_matrix_bottom);
|
| 838 |
+
|
| 839 |
+
sparse_matrix_top.SpMM_bias(x, bias, &out_view_top);
|
| 840 |
+
sparse_matrix_bottom.SpMM_bias(x, bias, &out_view_bottom);
|
| 841 |
+
|
| 842 |
+
CheckResult(out, out_view, kCols);
|
| 843 |
+
|
| 844 |
+
// Check that we can use a span as an input vector. Multiply upper left
|
| 845 |
+
// portion of the matrix by the top half of the vector.
|
| 846 |
+
// ---------------
|
| 847 |
+
// |oooooo | x q
|
| 848 |
+
// |oooooo | x q
|
| 849 |
+
// | | =
|
| 850 |
+
// | |
|
| 851 |
+
// ---------------
|
| 852 |
+
|
| 853 |
+
auto masked_matrix_quarter = MaskedSparseMatrix<float>(
|
| 854 |
+
kRows / 2, kCols / 2, mask.data(), values.data());
|
| 855 |
+
auto sparse_matrix_quarter =
|
| 856 |
+
CsrBlockSparseMatrix<bfloat16, float>(masked_matrix_quarter);
|
| 857 |
+
|
| 858 |
+
MutableVectorView<float> x_top(&x, 0, kCols / 2);
|
| 859 |
+
FatCacheAlignedVector<float> out_correct(kRows / 2, /*cols=*/2);
|
| 860 |
+
|
| 861 |
+
for (int i = 0; i < kFatness * (kRows / 2); ++i) out_correct[i] = 256.f;
|
| 862 |
+
|
| 863 |
+
MutableVectorView<float> bias_top(&bias, 0, kRows / 2);
|
| 864 |
+
FatCacheAlignedVector<float> out_quarter(kRows / 2, kFatness);
|
| 865 |
+
|
| 866 |
+
sparse_matrix_quarter.SpMM_bias(x_top, bias_top, &out_quarter);
|
| 867 |
+
|
| 868 |
+
CheckResult(out_correct, out_quarter, kCols / 2);
|
| 869 |
+
}
|
| 870 |
+
|
| 871 |
+
namespace {
|
| 872 |
+
|
| 873 |
+
bool skip_test(const absl::Status& status, absl::string_view msg) {
|
| 874 |
+
if (!status.ok()) {
|
| 875 |
+
LOG(INFO) << "Couldn't load " << msg << ", skipping test " << status;
|
| 876 |
+
return true;
|
| 877 |
+
}
|
| 878 |
+
|
| 879 |
+
return false;
|
| 880 |
+
}
|
| 881 |
+
|
| 882 |
+
} // namespace
|
| 883 |
+
|
| 884 |
+
TEST(CsrBlockSparseMatrix, ModelMatrices_Bfloat16) {
|
| 885 |
+
std::vector<std::string> names = {
|
| 886 |
+
"768_512_95_4x4_wavernn_gru_", "768_512_95_4x4_coarseproj_",
|
| 887 |
+
"768_512_95_4x4_coarselogit_", "768_512_95_4x4_fineproj_",
|
| 888 |
+
"768_512_95_4x4_finelogit_", "lyra_conv1d_"};
|
| 889 |
+
const std::string kPath =
|
| 890 |
+
#if defined __arm__ || defined __aarch64__
|
| 891 |
+
"/data/local/tmp/";
|
| 892 |
+
#else
|
| 893 |
+
(ghc::filesystem::current_path() / kTestdataPath).string();
|
| 894 |
+
#endif
|
| 895 |
+
for (auto& layer_name : names) {
|
| 896 |
+
SparseLinearLayer<bfloat16, float> sparse_linear_layer;
|
| 897 |
+
auto status = LoadSparseLayer<bfloat16, float>(layer_name, /*zipped=*/true,
|
| 898 |
+
&sparse_linear_layer, kPath);
|
| 899 |
+
// If the files don't exist on the device we're running on, just skip this
|
| 900 |
+
// test and log that it was skipped.
|
| 901 |
+
if (skip_test(status, layer_name)) return;
|
| 902 |
+
|
| 903 |
+
int rows = sparse_linear_layer.rows();
|
| 904 |
+
int cols = sparse_linear_layer.cols();
|
| 905 |
+
|
| 906 |
+
MaskedLinearLayer<float> masked_linear_layer;
|
| 907 |
+
status = LoadMaskedLayer<float>(layer_name, /*zipped=*/true,
|
| 908 |
+
&masked_linear_layer, kPath);
|
| 909 |
+
if (skip_test(status, layer_name)) return;
|
| 910 |
+
masked_linear_layer.CastWeights<csrblocksparse::bfloat16>();
|
| 911 |
+
|
| 912 |
+
CacheAlignedVector<float> rhs(cols);
|
| 913 |
+
CacheAlignedVector<float> out_ref(rows);
|
| 914 |
+
CacheAlignedVector<float> out_spmv(rows);
|
| 915 |
+
|
| 916 |
+
rhs.FillRandom();
|
| 917 |
+
out_ref.FillZero();
|
| 918 |
+
out_spmv.FillZero();
|
| 919 |
+
|
| 920 |
+
std::array<bool, 2> use_relus = {false, true};
|
| 921 |
+
for (bool use_relu : use_relus) {
|
| 922 |
+
masked_linear_layer.SpMM_bias(rhs, &out_ref, use_relu);
|
| 923 |
+
sparse_linear_layer.SpMM_bias(rhs, &out_spmv, use_relu);
|
| 924 |
+
|
| 925 |
+
CheckResult(out_ref, out_spmv, cols);
|
| 926 |
+
}
|
| 927 |
+
}
|
| 928 |
+
}
|
| 929 |
+
|
| 930 |
+
TEST(CsrBlockSparseMatrix, ModelMatrices_float) {
|
| 931 |
+
std::vector<std::string> names = {
|
| 932 |
+
"768_512_95_4x4_wavernn_gru_", "768_512_95_4x4_coarseproj_",
|
| 933 |
+
"768_512_95_4x4_coarselogit_", "768_512_95_4x4_fineproj_",
|
| 934 |
+
"768_512_95_4x4_finelogit_", "lyra_conv1d_"};
|
| 935 |
+
const std::string kPath =
|
| 936 |
+
#if defined __arm__ || defined __aarch64__
|
| 937 |
+
"/data/local/tmp/";
|
| 938 |
+
#else
|
| 939 |
+
(ghc::filesystem::current_path() / kTestdataPath).string();
|
| 940 |
+
#endif
|
| 941 |
+
for (auto& layer_name : names) {
|
| 942 |
+
SparseLinearLayer<float, float> sparse_linear_layer;
|
| 943 |
+
auto status = LoadSparseLayer<float, float>(layer_name, /*zipped=*/true,
|
| 944 |
+
&sparse_linear_layer, kPath);
|
| 945 |
+
// If the files don't exist on the device we're running on, just skip this
|
| 946 |
+
// test and log that it was skipped.
|
| 947 |
+
if (skip_test(status, layer_name)) return;
|
| 948 |
+
|
| 949 |
+
int rows = sparse_linear_layer.rows();
|
| 950 |
+
int cols = sparse_linear_layer.cols();
|
| 951 |
+
|
| 952 |
+
MaskedLinearLayer<float> masked_linear_layer;
|
| 953 |
+
status = LoadMaskedLayer<float>(layer_name, /*zipped=*/true,
|
| 954 |
+
&masked_linear_layer, kPath);
|
| 955 |
+
if (skip_test(status, layer_name)) return;
|
| 956 |
+
|
| 957 |
+
CacheAlignedVector<float> rhs(cols);
|
| 958 |
+
CacheAlignedVector<float> out_ref(rows);
|
| 959 |
+
CacheAlignedVector<float> out_spmv(rows);
|
| 960 |
+
|
| 961 |
+
rhs.FillRandom();
|
| 962 |
+
out_ref.FillZero();
|
| 963 |
+
out_spmv.FillZero();
|
| 964 |
+
|
| 965 |
+
std::array<bool, 2> use_relus = {false, true};
|
| 966 |
+
for (bool use_relu : use_relus) {
|
| 967 |
+
masked_linear_layer.SpMM_bias(rhs, &out_ref, use_relu);
|
| 968 |
+
sparse_linear_layer.SpMM_bias(rhs, &out_spmv, use_relu);
|
| 969 |
+
|
| 970 |
+
CheckResult(out_ref, out_spmv, cols);
|
| 971 |
+
}
|
| 972 |
+
}
|
| 973 |
+
}
|
| 974 |
+
|
| 975 |
+
#undef SKIP_TEST
|
| 976 |
+
|
| 977 |
+
} // namespace csrblocksparse
|
sparse_matmul/layers/errno_mapping.cc
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright 2021 Google LLC
|
| 2 |
+
//
|
| 3 |
+
// Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
// you may not use this file except in compliance with the License.
|
| 5 |
+
// You may obtain a copy of the License at
|
| 6 |
+
//
|
| 7 |
+
// http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
//
|
| 9 |
+
// Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
// distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
// See the License for the specific language governing permissions and
|
| 13 |
+
// limitations under the License.
|
| 14 |
+
|
| 15 |
+
#include "sparse_matmul/layers/errno_mapping.h"
|
| 16 |
+
|
| 17 |
+
#include <string>
|
| 18 |
+
|
| 19 |
+
#include "absl/strings/str_cat.h"
|
| 20 |
+
|
| 21 |
+
namespace csrblocksparse {
|
| 22 |
+
|
| 23 |
+
namespace {
|
| 24 |
+
|
| 25 |
+
absl::StatusCode ErrnoToCode(int error_number) {
|
| 26 |
+
switch (error_number) {
|
| 27 |
+
case 0:
|
| 28 |
+
return absl::StatusCode::kOk;
|
| 29 |
+
case EINVAL: // Invalid argument
|
| 30 |
+
case ENAMETOOLONG: // Filename too long
|
| 31 |
+
case E2BIG: // Argument list too long
|
| 32 |
+
case EDESTADDRREQ: // Destination address required
|
| 33 |
+
case EDOM: // Mathematics argument out of domain of function
|
| 34 |
+
case EFAULT: // Bad address
|
| 35 |
+
case EILSEQ: // Illegal byte sequence
|
| 36 |
+
case ENOPROTOOPT: // Protocol not available
|
| 37 |
+
case ENOSTR: // Not a STREAM
|
| 38 |
+
case ENOTSOCK: // Not a socket
|
| 39 |
+
case ENOTTY: // Inappropriate I/O control operation
|
| 40 |
+
case EPROTOTYPE: // Protocol wrong type for socket
|
| 41 |
+
case ESPIPE: // Invalid seek
|
| 42 |
+
return absl::StatusCode::kInvalidArgument;
|
| 43 |
+
case ETIMEDOUT: // Connection timed out
|
| 44 |
+
case ETIME: // Timer expired
|
| 45 |
+
return absl::StatusCode::kDeadlineExceeded;
|
| 46 |
+
case ENODEV: // No such device
|
| 47 |
+
case ENOENT: // No such file or directory
|
| 48 |
+
#ifdef ENOMEDIUM
|
| 49 |
+
case ENOMEDIUM: // No medium found
|
| 50 |
+
#endif
|
| 51 |
+
case ENXIO: // No such device or address
|
| 52 |
+
case ESRCH: // No such process
|
| 53 |
+
return absl::StatusCode::kNotFound;
|
| 54 |
+
case EEXIST: // File exists
|
| 55 |
+
case EADDRNOTAVAIL: // Address not available
|
| 56 |
+
case EALREADY: // Connection already in progress
|
| 57 |
+
#ifdef ENOTUNIQ
|
| 58 |
+
case ENOTUNIQ: // Name not unique on network
|
| 59 |
+
#endif
|
| 60 |
+
return absl::StatusCode::kAlreadyExists;
|
| 61 |
+
case EPERM: // Operation not permitted
|
| 62 |
+
case EACCES: // Permission denied
|
| 63 |
+
#ifdef ENOKEY
|
| 64 |
+
case ENOKEY: // Required key not available
|
| 65 |
+
#endif
|
| 66 |
+
case EROFS: // Read only file system
|
| 67 |
+
return absl::StatusCode::kPermissionDenied;
|
| 68 |
+
case ENOTEMPTY: // Directory not empty
|
| 69 |
+
case EISDIR: // Is a directory
|
| 70 |
+
case ENOTDIR: // Not a directory
|
| 71 |
+
case EADDRINUSE: // Address already in use
|
| 72 |
+
case EBADF: // Invalid file descriptor
|
| 73 |
+
#ifdef EBADFD
|
| 74 |
+
case EBADFD: // File descriptor in bad state
|
| 75 |
+
#endif
|
| 76 |
+
case EBUSY: // Device or resource busy
|
| 77 |
+
case ECHILD: // No child processes
|
| 78 |
+
case EISCONN: // Socket is connected
|
| 79 |
+
#ifdef EISNAM
|
| 80 |
+
case EISNAM: // Is a named type file
|
| 81 |
+
#endif
|
| 82 |
+
#ifdef ENOTBLK
|
| 83 |
+
case ENOTBLK: // Block device required
|
| 84 |
+
#endif
|
| 85 |
+
case ENOTCONN: // The socket is not connected
|
| 86 |
+
case EPIPE: // Broken pipe
|
| 87 |
+
#ifdef ESHUTDOWN
|
| 88 |
+
case ESHUTDOWN: // Cannot send after transport endpoint shutdown
|
| 89 |
+
#endif
|
| 90 |
+
case ETXTBSY: // Text file busy
|
| 91 |
+
#ifdef EUNATCH
|
| 92 |
+
case EUNATCH: // Protocol driver not attached
|
| 93 |
+
#endif
|
| 94 |
+
return absl::StatusCode::kFailedPrecondition;
|
| 95 |
+
case ENOSPC: // No space left on device
|
| 96 |
+
#ifdef EDQUOT
|
| 97 |
+
case EDQUOT: // Disk quota exceeded
|
| 98 |
+
#endif
|
| 99 |
+
case EMFILE: // Too many open files
|
| 100 |
+
case EMLINK: // Too many links
|
| 101 |
+
case ENFILE: // Too many open files in system
|
| 102 |
+
case ENOBUFS: // No buffer space available
|
| 103 |
+
case ENODATA: // No message is available on the STREAM read queue
|
| 104 |
+
case ENOMEM: // Not enough space
|
| 105 |
+
case ENOSR: // No STREAM resources
|
| 106 |
+
#ifdef EUSERS
|
| 107 |
+
case EUSERS: // Too many users
|
| 108 |
+
#endif
|
| 109 |
+
return absl::StatusCode::kResourceExhausted;
|
| 110 |
+
#ifdef ECHRNG
|
| 111 |
+
case ECHRNG: // Channel number out of range
|
| 112 |
+
#endif
|
| 113 |
+
case EFBIG: // File too large
|
| 114 |
+
case EOVERFLOW: // Value too large to be stored in data type
|
| 115 |
+
case ERANGE: // Result too large
|
| 116 |
+
return absl::StatusCode::kOutOfRange;
|
| 117 |
+
#ifdef ENOPKG
|
| 118 |
+
case ENOPKG: // Package not installed
|
| 119 |
+
#endif
|
| 120 |
+
case ENOSYS: // Function not implemented
|
| 121 |
+
case ENOTSUP: // Operation not supported
|
| 122 |
+
case EAFNOSUPPORT: // Address family not supported
|
| 123 |
+
#ifdef EPFNOSUPPORT
|
| 124 |
+
case EPFNOSUPPORT: // Protocol family not supported
|
| 125 |
+
#endif
|
| 126 |
+
case EPROTONOSUPPORT: // Protocol not supported
|
| 127 |
+
#ifdef ESOCKTNOSUPPORT
|
| 128 |
+
case ESOCKTNOSUPPORT: // Socket type not supported
|
| 129 |
+
#endif
|
| 130 |
+
case EXDEV: // Improper link
|
| 131 |
+
return absl::StatusCode::kUnimplemented;
|
| 132 |
+
case EAGAIN: // Resource temporarily unavailable
|
| 133 |
+
#ifdef ECOMM
|
| 134 |
+
case ECOMM: // Communication error on send
|
| 135 |
+
#endif
|
| 136 |
+
case ECONNREFUSED: // Connection refused
|
| 137 |
+
case ECONNABORTED: // Connection aborted
|
| 138 |
+
case ECONNRESET: // Connection reset
|
| 139 |
+
case EINTR: // Interrupted function call
|
| 140 |
+
#ifdef EHOSTDOWN
|
| 141 |
+
case EHOSTDOWN: // Host is down
|
| 142 |
+
#endif
|
| 143 |
+
case EHOSTUNREACH: // Host is unreachable
|
| 144 |
+
case ENETDOWN: // Network is down
|
| 145 |
+
case ENETRESET: // Connection aborted by network
|
| 146 |
+
case ENETUNREACH: // Network unreachable
|
| 147 |
+
case ENOLCK: // No locks available
|
| 148 |
+
case ENOLINK: // Link has been severed
|
| 149 |
+
#ifdef ENONET
|
| 150 |
+
case ENONET: // Machine is not on the network
|
| 151 |
+
#endif
|
| 152 |
+
return absl::StatusCode::kUnavailable;
|
| 153 |
+
case EDEADLK: // Resource deadlock avoided
|
| 154 |
+
#ifdef ESTALE
|
| 155 |
+
case ESTALE: // Stale file handle
|
| 156 |
+
#endif
|
| 157 |
+
return absl::StatusCode::kAborted;
|
| 158 |
+
case ECANCELED: // Operation cancelled
|
| 159 |
+
return absl::StatusCode::kCancelled;
|
| 160 |
+
default:
|
| 161 |
+
return absl::StatusCode::kUnknown;
|
| 162 |
+
}
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
// POSIX `strerror_r()` returns `int`.
|
| 166 |
+
ABSL_ATTRIBUTE_UNUSED std::string StrErrorResult(int result, const char* buffer,
|
| 167 |
+
int error_code) {
|
| 168 |
+
if (ABSL_PREDICT_FALSE(result != 0)) {
|
| 169 |
+
return absl::StrCat("Unknown error ", error_code);
|
| 170 |
+
}
|
| 171 |
+
return buffer;
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
// GNU `strerror_r()` returns `char*`.
|
| 175 |
+
ABSL_ATTRIBUTE_UNUSED std::string StrErrorResult(char* result,
|
| 176 |
+
const char* buffer,
|
| 177 |
+
int error_code) {
|
| 178 |
+
return result;
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
std::string StrError(int error_code) {
|
| 182 |
+
char message[256];
|
| 183 |
+
return StrErrorResult(strerror_r(error_code, message, sizeof(message)),
|
| 184 |
+
message, error_code);
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
} // namespace
|
| 188 |
+
|
| 189 |
+
absl::Status ErrnoToCanonicalStatus(int error_number,
|
| 190 |
+
absl::string_view message) {
|
| 191 |
+
return absl::Status(ErrnoToCode(error_number),
|
| 192 |
+
absl::StrCat(message, ": ", StrError(error_number)));
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
} // namespace csrblocksparse
|
sparse_matmul/layers/errno_mapping.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright 2021 Google LLC
|
| 2 |
+
//
|
| 3 |
+
// Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
// you may not use this file except in compliance with the License.
|
| 5 |
+
// You may obtain a copy of the License at
|
| 6 |
+
//
|
| 7 |
+
// http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
//
|
| 9 |
+
// Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
// distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
// See the License for the specific language governing permissions and
|
| 13 |
+
// limitations under the License.
|
| 14 |
+
|
| 15 |
+
#ifndef THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_ERRNO_MAPPING_H_
|
| 16 |
+
#define THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_ERRNO_MAPPING_H_
|
| 17 |
+
|
| 18 |
+
#include "absl/status/status.h"
|
| 19 |
+
#include "absl/strings/string_view.h"
|
| 20 |
+
|
| 21 |
+
namespace csrblocksparse {
|
| 22 |
+
|
| 23 |
+
// Converts |error_number| value to absl::Status.
|
| 24 |
+
absl::Status ErrnoToCanonicalStatus(int error_number,
|
| 25 |
+
absl::string_view message);
|
| 26 |
+
|
| 27 |
+
} // namespace csrblocksparse
|
| 28 |
+
|
| 29 |
+
#endif // THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_ERRNO_MAPPING_H_
|
sparse_matmul/layers/masked_sparse_matrix.h
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2021 Google LLC
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_
|
| 18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_
|
| 19 |
+
|
| 20 |
+
#include <algorithm>
|
| 21 |
+
#include <cstdio>
|
| 22 |
+
#include <numeric>
|
| 23 |
+
#include <vector>
|
| 24 |
+
|
| 25 |
+
#include "absl/strings/str_format.h"
|
| 26 |
+
#include "sparse_matmul/vector/cache_aligned_vector.h"
|
| 27 |
+
|
| 28 |
+
namespace csrblocksparse {
|
| 29 |
+
|
| 30 |
+
// MaskedSparseMatrix serves two purposes:
|
| 31 |
+
// 1) It is useful as a reference implementation of SpMV for correctness
|
| 32 |
+
// checking the much more complicated implementations in CSRBlockSparseMatrix
|
| 33 |
+
// 2) This is the format that sparse matrices are represented after pruning
|
| 34 |
+
// in TF. This class provides a bridge to getting these parameters into
|
| 35 |
+
// a compressed form suitable for computation and serialization.
|
| 36 |
+
//
|
| 37 |
+
// MaskedSparseMatrix<float> matrix(rows, cols, mask_from_tf, values_from_tf);
|
| 38 |
+
// CSRBlockSparseMatrix<float, bfloat16, int16_t> csr_matrix(matrix);
|
| 39 |
+
// csr_matrix.Multiply(rhs, bias, &out);
|
| 40 |
+
template <typename T>
|
| 41 |
+
class MaskedSparseMatrix {
|
| 42 |
+
public:
|
| 43 |
+
MaskedSparseMatrix() {}
|
| 44 |
+
|
| 45 |
+
// Construct a MaskedSparseMatrix of the given size, sparsity and block size.
|
| 46 |
+
// This is mainly useful for testing.
|
| 47 |
+
MaskedSparseMatrix(int rows, int cols, float sparsity, int block_height = 1,
|
| 48 |
+
int block_width = 1, float constant = 1.f,
|
| 49 |
+
bool random = true)
|
| 50 |
+
: rows_(rows), cols_(cols), sparsity_(sparsity) {
|
| 51 |
+
CHECK_EQ(rows % block_height, 0);
|
| 52 |
+
CHECK_EQ(cols % block_width, 0);
|
| 53 |
+
|
| 54 |
+
init(sparsity, block_height, block_width, constant, random);
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
// Construct from an existing mask and values (most likely from a TF model).
|
| 58 |
+
template <typename MaskType>
|
| 59 |
+
MaskedSparseMatrix(int rows, int cols, const MaskType* mask, const T* values)
|
| 60 |
+
: rows_(rows), cols_(cols) {
|
| 61 |
+
mask_.resize(rows * cols);
|
| 62 |
+
values_.resize(rows * cols);
|
| 63 |
+
std::copy_n(mask, rows * cols, mask_.begin());
|
| 64 |
+
std::copy_n(values, rows * cols, values_.begin());
|
| 65 |
+
sparsity_ =
|
| 66 |
+
1.f - std::accumulate(mask_.begin(), mask_.end(), 0.f) / mask_.size();
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
const std::vector<int>& mask() const { return mask_; }
|
| 70 |
+
const std::vector<T>& values() const { return values_; }
|
| 71 |
+
T* data() { return values_.data(); }
|
| 72 |
+
const T* data() const { return values_.data(); }
|
| 73 |
+
|
| 74 |
+
int rows() const { return rows_; }
|
| 75 |
+
int cols() const { return cols_; }
|
| 76 |
+
float sparsity() const { return sparsity_; }
|
| 77 |
+
|
| 78 |
+
void Print() const {
|
| 79 |
+
absl::PrintF("-------Values---------\n");
|
| 80 |
+
for (int r = 0; r < rows_; ++r) {
|
| 81 |
+
for (int c = 0; c < cols_; ++c) {
|
| 82 |
+
absl::PrintF("%+6.3f ", static_cast<float>(values_[r * cols_ + c]));
|
| 83 |
+
}
|
| 84 |
+
absl::PrintF("\n");
|
| 85 |
+
}
|
| 86 |
+
absl::PrintF("-------Mask---------\n");
|
| 87 |
+
for (int r = 0; r < rows_; ++r) {
|
| 88 |
+
for (int c = 0; c < cols_; ++c) {
|
| 89 |
+
printf("%2d ", mask_[r * cols_ + c]);
|
| 90 |
+
}
|
| 91 |
+
absl::PrintF("\n");
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
// This routine is useful for rounding the possibly higher precision values
|
| 96 |
+
// stored in this class to a lower precision, so that correctness checks
|
| 97 |
+
// between this class and CSRBlockSparseMatrix can have a tighter tolerance.
|
| 98 |
+
template <typename U>
|
| 99 |
+
void CastWeights() {
|
| 100 |
+
for (int i = 0; i < values_.size(); ++i) {
|
| 101 |
+
values_[i] = static_cast<T>(U(values_[i]));
|
| 102 |
+
}
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
// Only meant for correctness checking.
|
| 106 |
+
// RhsClassType is meant to be either CacheAlignedVector OR
|
| 107 |
+
// FatCacheAlignedVector.
|
| 108 |
+
// The weight matrix is ROW MAJOR and RhsClassType is COLUMN MAJOR.
|
| 109 |
+
// |bias| is broadcast if |rhs| has more than one column.
|
| 110 |
+
template <typename RhsClassType, typename BiasType, typename OutClassType,
|
| 111 |
+
typename RhsType = typename RhsClassType::value_type,
|
| 112 |
+
typename OutType = typename OutClassType::value_type>
|
| 113 |
+
void SpMM_bias(const RhsClassType& rhs,
|
| 114 |
+
const CacheAlignedVector<BiasType>& bias, OutClassType* out,
|
| 115 |
+
bool relu = false) {
|
| 116 |
+
for (int r = 0; r < rows_; ++r) {
|
| 117 |
+
for (int n = 0; n < rhs.cols(); ++n) {
|
| 118 |
+
float sum = 0.f;
|
| 119 |
+
const RhsType* rhs_ptr = rhs.data() + n * rhs.rows();
|
| 120 |
+
OutType* out_ptr = out->data() + n * out->rows();
|
| 121 |
+
const int* mask_ptr = mask_.data() + r * cols_;
|
| 122 |
+
const T* value_ptr = values_.data() + r * cols_;
|
| 123 |
+
for (int c = 0; c < cols_; ++c) {
|
| 124 |
+
sum += mask_ptr[c] * static_cast<float>(value_ptr[c]) *
|
| 125 |
+
static_cast<float>(rhs_ptr[c]);
|
| 126 |
+
}
|
| 127 |
+
out_ptr[r] = static_cast<OutType>(
|
| 128 |
+
relu ? std::max(sum + static_cast<float>(bias[r]), 0.f)
|
| 129 |
+
: sum + static_cast<float>(bias[r]));
|
| 130 |
+
}
|
| 131 |
+
}
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
private:
|
| 135 |
+
// Generate a random matrix with the specified sparsity.
|
| 136 |
+
// Useful for testing.
|
| 137 |
+
void init(float sparsity, int block_height, int block_width, float constant,
|
| 138 |
+
bool random = true) {
|
| 139 |
+
int reduced_rows = rows_ / block_height;
|
| 140 |
+
int reduced_cols = cols_ / block_width;
|
| 141 |
+
mask_.resize(rows_ * cols_, 0);
|
| 142 |
+
|
| 143 |
+
// Fill with non-zero value to make sure masking works.
|
| 144 |
+
values_.resize(rows_ * cols_, static_cast<T>(2.f));
|
| 145 |
+
|
| 146 |
+
std::mt19937 generator(0);
|
| 147 |
+
std::uniform_real_distribution<float> dist_sparsity;
|
| 148 |
+
std::uniform_real_distribution<float> dist_value(-1.f, 1.f);
|
| 149 |
+
int nnz = 0;
|
| 150 |
+
while (nnz == 0) {
|
| 151 |
+
for (int r = 0; r < reduced_rows; ++r) {
|
| 152 |
+
for (int c = 0; c < reduced_cols; ++c) {
|
| 153 |
+
if (dist_sparsity(generator) > sparsity) {
|
| 154 |
+
nnz++;
|
| 155 |
+
for (int i = 0; i < block_height; ++i) {
|
| 156 |
+
for (int j = 0; j < block_width; ++j) {
|
| 157 |
+
mask_[(r * block_height + i) * cols_ + block_width * c + j] = 1;
|
| 158 |
+
values_[(r * block_height + i) * cols_ + block_width * c + j] =
|
| 159 |
+
static_cast<T>(random ? dist_value(generator) : constant);
|
| 160 |
+
}
|
| 161 |
+
}
|
| 162 |
+
}
|
| 163 |
+
}
|
| 164 |
+
}
|
| 165 |
+
}
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
std::vector<int> mask_;
|
| 169 |
+
std::vector<T> values_;
|
| 170 |
+
int rows_;
|
| 171 |
+
int cols_;
|
| 172 |
+
float sparsity_;
|
| 173 |
+
};
|
| 174 |
+
|
| 175 |
+
template <typename T>
|
| 176 |
+
class MaskedLinearLayer {
|
| 177 |
+
public:
|
| 178 |
+
MaskedLinearLayer(MaskedSparseMatrix<T>&& weights,
|
| 179 |
+
CacheAlignedVector<T>&& bias)
|
| 180 |
+
: weights_(std::move(weights)), bias_(std::move(bias)) {}
|
| 181 |
+
|
| 182 |
+
MaskedLinearLayer() {}
|
| 183 |
+
|
| 184 |
+
template <typename U>
|
| 185 |
+
void CastWeights() {
|
| 186 |
+
weights_.template CastWeights<U>();
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
// Does Ax + b where A is a masked sparse ROW MAJOR matrix and
|
| 190 |
+
// x is a COLUMN MAJOR dense vector or matrix. Bias is a vector that is
|
| 191 |
+
// broadcast is rhs has more than one column.
|
| 192 |
+
template <typename FatVector>
|
| 193 |
+
void SpMM_bias(const FatVector& rhs, FatVector* out, bool relu = false) {
|
| 194 |
+
static_assert(std::is_same<typename FatVector::value_type, T>::value,
|
| 195 |
+
"FatVector value_type must match masked_linear_layer type");
|
| 196 |
+
weights_.SpMM_bias(rhs, bias_, out, relu);
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
private:
|
| 200 |
+
MaskedSparseMatrix<T> weights_;
|
| 201 |
+
CacheAlignedVector<T> bias_;
|
| 202 |
+
};
|
| 203 |
+
|
| 204 |
+
} // namespace csrblocksparse
|
| 205 |
+
|
| 206 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_
|
sparse_matmul/layers/read_array_ifstream.h
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2021 Google LLC
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
// Low-level array reading function using std::ifstream.
|
| 18 |
+
|
| 19 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_READ_ARRAY_IFSTREAM_H_
|
| 20 |
+
#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_READ_ARRAY_IFSTREAM_H_
|
| 21 |
+
|
| 22 |
+
#include <cstdint>
|
| 23 |
+
#include <fstream>
|
| 24 |
+
#include <sstream>
|
| 25 |
+
#include <string>
|
| 26 |
+
|
| 27 |
+
#include "absl/status/status.h"
|
| 28 |
+
#include "absl/strings/substitute.h"
|
| 29 |
+
#include "include/ghc/filesystem.hpp"
|
| 30 |
+
|
| 31 |
+
namespace csrblocksparse {
|
| 32 |
+
namespace detail {
|
| 33 |
+
|
| 34 |
+
template <typename T>
|
| 35 |
+
absl::Status ReadArrayIfstream(const std::string& file_name,
|
| 36 |
+
const std::string& path, std::vector<T>* array,
|
| 37 |
+
int64_t* length) {
|
| 38 |
+
ghc::filesystem::path complete_path(path);
|
| 39 |
+
complete_path /= file_name;
|
| 40 |
+
std::ifstream in_stream(complete_path.u8string(), std::ios::binary);
|
| 41 |
+
if (!in_stream.is_open()) {
|
| 42 |
+
return absl::UnknownError(
|
| 43 |
+
absl::Substitute("Error opening $0", complete_path.string()));
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
std::stringstream buffer;
|
| 47 |
+
buffer << in_stream.rdbuf();
|
| 48 |
+
if (buffer.str().empty()) {
|
| 49 |
+
LOG(ERROR) << "File " << complete_path << " was empty.";
|
| 50 |
+
return absl::UnknownError(
|
| 51 |
+
absl::Substitute("File $0 was empty", complete_path.string()));
|
| 52 |
+
}
|
| 53 |
+
std::string contents = buffer.str();
|
| 54 |
+
*length = contents.length();
|
| 55 |
+
int64_t elem = (*length + sizeof(T) - 1) / sizeof(T);
|
| 56 |
+
array->resize(elem);
|
| 57 |
+
std::move(contents.begin(), contents.end(),
|
| 58 |
+
reinterpret_cast<char*>(array->data()));
|
| 59 |
+
|
| 60 |
+
return absl::OkStatus();
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
} // namespace detail
|
| 64 |
+
} // namespace csrblocksparse
|
| 65 |
+
|
| 66 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_READ_ARRAY_IFSTREAM_H_
|
sparse_matmul/layers/sparse_linear_layer.h
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright 2021 Google LLC
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_
|
| 18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_
|
| 19 |
+
|
| 20 |
+
#include <cstdint>
|
| 21 |
+
|
| 22 |
+
#include "absl/memory/memory.h"
|
| 23 |
+
#include "glog/logging.h"
|
| 24 |
+
#include "sparse_matmul/layers/csr_blocksparse_matrix.h"
|
| 25 |
+
#include "sparse_matmul/layers/masked_sparse_matrix.h"
|
| 26 |
+
#include "sparse_matmul/numerics/type_utils.h"
|
| 27 |
+
#include "sparse_matmul/os/coop_threads.h"
|
| 28 |
+
#include "sparse_matmul/vector/cache_aligned_vector.h"
|
| 29 |
+
|
| 30 |
+
namespace csrblocksparse {
|
| 31 |
+
|
| 32 |
+
template <typename WeightType, typename RhsType,
|
| 33 |
+
typename BiasType = typename TypeOfProduct<WeightType, RhsType>::type,
|
| 34 |
+
typename DeltaType = int16_t>
|
| 35 |
+
class SparseLinearLayer {
|
| 36 |
+
public:
|
| 37 |
+
SparseLinearLayer() {}
|
| 38 |
+
|
| 39 |
+
SparseLinearLayer(CsrBlockSparseMatrix<WeightType, RhsType>&& sparse_matrix,
|
| 40 |
+
CacheAlignedVector<BiasType>&& bias)
|
| 41 |
+
: sparse_matrix_(std::move(sparse_matrix)), full_bias_(std::move(bias)) {
|
| 42 |
+
CHECK_EQ(sparse_matrix_.rows(), full_bias_.size());
|
| 43 |
+
// Some kernels expect that the bias is divided by 4, so we store a second
|
| 44 |
+
// copy of a quarter of the bias.
|
| 45 |
+
// TODO(b/189958858): Remove the quartered bias if it can be done without
|
| 46 |
+
// loss of speed, and rename the |full_bias_| member back to |bias_|.
|
| 47 |
+
bias_ = full_bias_;
|
| 48 |
+
for (int i = 0; i < bias_.size(); ++i) {
|
| 49 |
+
bias_[i] = static_cast<BiasType>(.25f * static_cast<float>(bias_[i]));
|
| 50 |
+
}
|
| 51 |
+
}
|
| 52 |
+
SparseLinearLayer(
|
| 53 |
+
const SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>& src) {
|
| 54 |
+
*this = src;
|
| 55 |
+
}
|
| 56 |
+
SparseLinearLayer& operator=(
|
| 57 |
+
const SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>& src) {
|
| 58 |
+
sparse_matrix_ = src.sparse_matrix_;
|
| 59 |
+
bias_ = src.bias_;
|
| 60 |
+
full_bias_ = src.full_bias_;
|
| 61 |
+
mid_output_ = src.mid_output_;
|
| 62 |
+
thread_layers_ = src.thread_layers_;
|
| 63 |
+
num_threads_ = src.num_threads_;
|
| 64 |
+
if (src.split_pc_) {
|
| 65 |
+
split_pc_ = absl::make_unique<ProducerConsumer>(
|
| 66 |
+
src.split_pc_->num_producers(), src.split_pc_->num_consumers());
|
| 67 |
+
}
|
| 68 |
+
return *this;
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
// Does Ax + b where A is a block sparse compressed sparse row matrix and
|
| 72 |
+
// x is a COLUMN MAJOR dense vector or matrix. Bias is a vector that is
|
| 73 |
+
// broadcast if rhs has more than one column.
|
| 74 |
+
template <typename RhsClassType, typename OutType>
|
| 75 |
+
void SpMM_bias(const RhsClassType& rhs, OutType* out, bool relu = false,
|
| 76 |
+
int tid = 0, SpinBarrier* barrier = nullptr) const {
|
| 77 |
+
static_assert(
|
| 78 |
+
std::is_same<typename RhsClassType::value_type, RhsType>::value, "");
|
| 79 |
+
sparse_matrix_.SpMM_bias(rhs, bias_, out, relu, tid, barrier);
|
| 80 |
+
}
|
| 81 |
+
// Multiplies a sparse matrix by a possibly dense matrix, as SpMM_bias above,
|
| 82 |
+
// and then samples from the output (softmax distribution) layer.
|
| 83 |
+
template <typename RhsClassType, typename OutType>
|
| 84 |
+
int SpMM_bias_Sample(const RhsClassType& rhs, OutType* out, float temperature,
|
| 85 |
+
int tid, SpinBarrier* barrier, std::minstd_rand* gen,
|
| 86 |
+
CacheAlignedVector<float>* scratch) const {
|
| 87 |
+
static_assert(
|
| 88 |
+
std::is_same<typename RhsClassType::value_type, RhsType>::value, "");
|
| 89 |
+
return sparse_matrix_.SpMM_bias_Sample(rhs, bias_, out, temperature, tid,
|
| 90 |
+
barrier, gen, scratch);
|
| 91 |
+
}
|
| 92 |
+
template <typename RhsClassType, typename OutType>
|
| 93 |
+
void MatVec(const RhsClassType& rhs, bool relu, int tid, int replicas,
|
| 94 |
+
int output_stride, OutType* output,
|
| 95 |
+
SpinBarrier* barrier = nullptr) {
|
| 96 |
+
static_assert(
|
| 97 |
+
std::is_same<typename RhsClassType::value_type, RhsType>::value, "");
|
| 98 |
+
#ifdef __AVX2__
|
| 99 |
+
if (block_width() == 4 && (block_height() == 4 || block_height() == 8) &&
|
| 100 |
+
!IsCustomFloatType<WeightType>::value) {
|
| 101 |
+
if (!IsSplit()) {
|
| 102 |
+
sparse_matrix_.MatVec(rhs.cast_data(), full_bias_.cast_data(), relu,
|
| 103 |
+
tid, replicas, output_stride, output->data());
|
| 104 |
+
if (barrier != nullptr) barrier->barrier();
|
| 105 |
+
return;
|
| 106 |
+
}
|
| 107 |
+
// NOTE: Until the quartered bias is removed it is a bad idea to split
|
| 108 |
+
// for ARM in the same way, as we would have to quarter the output of
|
| 109 |
+
// the first part of the split before running the second part.
|
| 110 |
+
// Signal completion of the previous MatVec.
|
| 111 |
+
split_pc_->produce();
|
| 112 |
+
PartLinearLayer& thread_part = thread_layers_[tid];
|
| 113 |
+
auto offset_output =
|
| 114 |
+
sparse_matrix_.thread_bounds().OffsetOutput(output->data(), tid);
|
| 115 |
+
auto mid_output =
|
| 116 |
+
sparse_matrix_.thread_bounds().OffsetOutput(mid_output_.data(), tid);
|
| 117 |
+
auto offset_bias = sparse_matrix_.thread_bounds().OffsetOutput(
|
| 118 |
+
mid_output_.cast_data(), tid);
|
| 119 |
+
// We can continue to consume the data that this thread produced and
|
| 120 |
+
// compute just the |self_matrix| part.
|
| 121 |
+
// No |relu| or |replicas|, as this is only a partial matmul.
|
| 122 |
+
// |tid| is always zero because the matrix has been split by tid.
|
| 123 |
+
thread_part.self_matrix.MatVec(
|
| 124 |
+
rhs.cast_data(), thread_part.full_bias.cast_data(), /*relu=*/false,
|
| 125 |
+
/*tid=*/0, /*replicas=*/1, output_stride, mid_output);
|
| 126 |
+
// We have to wait for the other threads to finish working on the previous
|
| 127 |
+
// MatMul before consuming the rest of |rhs|.
|
| 128 |
+
split_pc_->consume();
|
| 129 |
+
thread_part.other_matrix.MatVec(rhs.cast_data(), offset_bias, relu,
|
| 130 |
+
/*tid=*/0, replicas, output_stride,
|
| 131 |
+
offset_output);
|
| 132 |
+
return;
|
| 133 |
+
}
|
| 134 |
+
#endif
|
| 135 |
+
DCHECK_EQ(replicas, 1) << "Must have single replica for SpMM API";
|
| 136 |
+
if (IsSplit()) {
|
| 137 |
+
// Generics aren't setup to use a split matrix. This will be inefficient.
|
| 138 |
+
split_pc_->produce();
|
| 139 |
+
split_pc_->consume();
|
| 140 |
+
}
|
| 141 |
+
if (block_height() == 8) {
|
| 142 |
+
// We are currently forced to use MatVec generics for this case.
|
| 143 |
+
LOG(WARNING) << "Need to implement MatVec for 8x4 for non-AVX2 targets!!";
|
| 144 |
+
sparse_matrix_.MatVec(rhs.cast_data(), full_bias_.cast_data(), relu, tid,
|
| 145 |
+
replicas, output_stride, output->data());
|
| 146 |
+
if (barrier != nullptr) barrier->barrier();
|
| 147 |
+
} else {
|
| 148 |
+
sparse_matrix_.SpMM_bias(rhs, bias_, output, relu, tid, barrier);
|
| 149 |
+
}
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
int rows() const { return sparse_matrix_.rows(); }
|
| 153 |
+
int cols() const { return sparse_matrix_.cols(); }
|
| 154 |
+
float sparsity() const { return sparse_matrix_.sparsity(); }
|
| 155 |
+
int block_width() const { return sparse_matrix_.block_width(); }
|
| 156 |
+
int block_height() const { return sparse_matrix_.block_height(); }
|
| 157 |
+
int num_threads() const { return sparse_matrix_.num_threads(); }
|
| 158 |
+
const CacheAlignedVector<BiasType>& bias() const { return bias_; }
|
| 159 |
+
const std::vector<int>& split_points() const {
|
| 160 |
+
return sparse_matrix_.split_points();
|
| 161 |
+
}
|
| 162 |
+
bool IsSplit() const {
|
| 163 |
+
return !thread_layers_.empty() && split_pc_ != nullptr;
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
std::size_t bytes() const { return sparse_matrix_.bytes() + bias_.bytes(); }
|
| 167 |
+
void Print() const {
|
| 168 |
+
printf("Matrix\n");
|
| 169 |
+
sparse_matrix_.Print();
|
| 170 |
+
printf("Bias\n");
|
| 171 |
+
bias_.Print();
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
// Combines adjacent row blocks, doubling the block height.
|
| 175 |
+
// This necessarily involves adding zero weights where the blocks don't align
|
| 176 |
+
// across adjacent pairs of rows, so use with caution, as the resulting matrix
|
| 177 |
+
// is most likely to run slower if very sparse to begin with.
|
| 178 |
+
// In the few cases where the blocks do mostly align, the resulting matmul
|
| 179 |
+
// could be much faster, as the number of reads of the rhs will be halved.
|
| 180 |
+
void DoubleBlockHeight() { sparse_matrix_.DoubleBlockHeight(); }
|
| 181 |
+
|
| 182 |
+
// Cache_line_size is provided only for testing. Normally uses a value for
|
| 183 |
+
// the current architecture.
|
| 184 |
+
int PrepareForThreads(int num_threads, int cache_line_size = -1) {
|
| 185 |
+
num_threads_ = num_threads;
|
| 186 |
+
if (num_threads_ > 1) {
|
| 187 |
+
split_pc_ =
|
| 188 |
+
absl::make_unique<ProducerConsumer>(num_threads_, num_threads_);
|
| 189 |
+
} else {
|
| 190 |
+
split_pc_.reset(nullptr);
|
| 191 |
+
}
|
| 192 |
+
return sparse_matrix_.PrepareForThreads(num_threads, cache_line_size);
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
// Partitions the matrix into pieces by thread.
|
| 196 |
+
// In this matrix, we can go ahead and calculate the part that only depends
|
| 197 |
+
// on rhs inputs that were generated by this thread in the previous matvec,
|
| 198 |
+
// without having to use any thread synchronization, and only after that do we
|
| 199 |
+
// have to wait for the other threads to finish the previous matvec.
|
| 200 |
+
// So we split the matrix using the |split_points| from the previous matrix
|
| 201 |
+
// into 2 * |num_threads_| pieces: self and other for each thread, being the
|
| 202 |
+
// parts that can be calculated before and after the other threads have
|
| 203 |
+
// completed their calculation of the previous matvec.
|
| 204 |
+
// We then have to use a ProducerConsumer lock instead of a SpinBarrier to
|
| 205 |
+
// synchronize the data produced by the other threads.
|
| 206 |
+
void SliceForThreads(const std::vector<int>& split_points) {
|
| 207 |
+
thread_layers_.clear();
|
| 208 |
+
thread_layers_.reserve(num_threads_);
|
| 209 |
+
LOG(INFO) << "Slicing " << rows() << "x" << cols() << " matrix for "
|
| 210 |
+
<< num_threads_ << " threads";
|
| 211 |
+
for (int tid = 0; tid < num_threads_; ++tid) {
|
| 212 |
+
thread_layers_.emplace_back(
|
| 213 |
+
sparse_matrix_, full_bias_, bias_, tid,
|
| 214 |
+
split_points[tid] * sparse_matrix_.block_height(),
|
| 215 |
+
split_points[tid + 1] * sparse_matrix_.block_height());
|
| 216 |
+
}
|
| 217 |
+
mid_output_ =
|
| 218 |
+
std::move(csrblocksparse::CacheAlignedVector<BiasType>(rows()));
|
| 219 |
+
mid_output_.FillZero();
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
// Splits the layer by inputs into 2 equal pieces. Each of the resulting
|
| 223 |
+
// layers should be computed independently on the first and second halves of
|
| 224 |
+
// the inputs respectively and the results added to achieve the same effect
|
| 225 |
+
// as the original layer.
|
| 226 |
+
void SplitInputs(
|
| 227 |
+
SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part1,
|
| 228 |
+
SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part2) {
|
| 229 |
+
CsrBlockSparseMatrix<WeightType, RhsType> matrix1(
|
| 230 |
+
sparse_matrix_.SplitByColumn(0, sparse_matrix_.cols() / 2));
|
| 231 |
+
CsrBlockSparseMatrix<WeightType, RhsType> matrix2(
|
| 232 |
+
sparse_matrix_.SplitByColumn(sparse_matrix_.cols() / 2,
|
| 233 |
+
sparse_matrix_.cols()));
|
| 234 |
+
*part1 =
|
| 235 |
+
std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
|
| 236 |
+
std::move(matrix1),
|
| 237 |
+
std::move(CacheAlignedVector<BiasType>(full_bias_))));
|
| 238 |
+
CacheAlignedVector<BiasType> bias2(sparse_matrix_.rows());
|
| 239 |
+
bias2.FillZero();
|
| 240 |
+
*part2 =
|
| 241 |
+
std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
|
| 242 |
+
std::move(matrix2), std::move(bias2)));
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
// Splits the layer by outputs into 2 equal pieces. Each of the resulting
|
| 246 |
+
// layers should be computed independently on the full inputs and the results
|
| 247 |
+
// concatenated to achieve the same effect as the original layer.
|
| 248 |
+
void SplitOutputs(
|
| 249 |
+
SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part1,
|
| 250 |
+
SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part2) {
|
| 251 |
+
LOG(INFO) << "input rows=" << sparse_matrix_.rows()
|
| 252 |
+
<< ", cols=" << sparse_matrix_.cols();
|
| 253 |
+
CsrBlockSparseMatrix<WeightType, RhsType> matrix1(
|
| 254 |
+
sparse_matrix_.SplitByRow(0, sparse_matrix_.rows() / 2));
|
| 255 |
+
CsrBlockSparseMatrix<WeightType, RhsType> matrix2(sparse_matrix_.SplitByRow(
|
| 256 |
+
sparse_matrix_.rows() / 2, sparse_matrix_.rows()));
|
| 257 |
+
CacheAlignedVector<BiasType> bias1(full_bias_, 0, full_bias_.size() / 2);
|
| 258 |
+
*part1 =
|
| 259 |
+
std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
|
| 260 |
+
std::move(matrix1), std::move(bias1)));
|
| 261 |
+
CacheAlignedVector<BiasType> bias2(full_bias_, full_bias_.size() / 2,
|
| 262 |
+
full_bias_.size());
|
| 263 |
+
*part2 =
|
| 264 |
+
std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
|
| 265 |
+
std::move(matrix2), std::move(bias2)));
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
private:
|
| 269 |
+
// Simple struct to hold a partitioned layer.
|
| 270 |
+
struct PartLinearLayer {
|
| 271 |
+
// The original matrix is first split by row to generate only the outputs
|
| 272 |
+
// for the given tid. The |row_sub_matrix| is then split by column into two
|
| 273 |
+
// partitions:
|
| 274 |
+
// self is the part for which the rhs elements in [|start_col|, |end_col|)
|
| 275 |
+
// were generated by this thread in some previous matmul.
|
| 276 |
+
// |other| is the rest of the columns that require rhs elements from other
|
| 277 |
+
// threads.
|
| 278 |
+
// NOTE that| start_col|, |end_col| are in raw columns, not blocks.
|
| 279 |
+
PartLinearLayer(const CsrBlockSparseMatrix<WeightType, RhsType>& matrix,
|
| 280 |
+
const CacheAlignedVector<BiasType>& bias,
|
| 281 |
+
const CacheAlignedVector<BiasType>& bias_4, int tid,
|
| 282 |
+
int start_col, int end_col) {
|
| 283 |
+
int block_height = matrix.block_height();
|
| 284 |
+
// Split the input matrix by row, selecting only the rows relevant to
|
| 285 |
+
// thread tid.
|
| 286 |
+
int start_row = matrix.split_points()[tid] * block_height;
|
| 287 |
+
int end_row = matrix.split_points()[tid + 1] * block_height;
|
| 288 |
+
LOG(INFO) << "input cols [" << start_col << "," << end_col << ") rows ["
|
| 289 |
+
<< start_row << "," << end_row << ")";
|
| 290 |
+
CsrBlockSparseMatrix<WeightType, RhsType> row_sub_matrix =
|
| 291 |
+
matrix.SplitByRow(start_row, end_row);
|
| 292 |
+
// Partition into the columns that use rhs elements that thread tid
|
| 293 |
+
// produced in a previous matmul, and the other rhs elements.
|
| 294 |
+
// NOTE that we |keep_rhs_size|=true so that each matrix can operate on
|
| 295 |
+
// the same rhs input vector. The self matrix just guarantees not to
|
| 296 |
+
// access any of the elements that are generated by another thread.
|
| 297 |
+
self_matrix = std::move(row_sub_matrix.SplitByColumn(
|
| 298 |
+
start_col, end_col, /*keep_rhs_size=*/true));
|
| 299 |
+
self_matrix.PrepareForThreads(1);
|
| 300 |
+
// The reversed start and end slice out the complement of [start, end).
|
| 301 |
+
other_matrix = std::move(row_sub_matrix.SplitByColumn(
|
| 302 |
+
end_col, start_col, /*keep_rhs_size=*/true));
|
| 303 |
+
other_matrix.PrepareForThreads(1);
|
| 304 |
+
full_bias =
|
| 305 |
+
std::move(CacheAlignedVector<BiasType>(bias, start_row, end_row));
|
| 306 |
+
// TODO(b/189958858): Eliminate the quarter bias from all the code.
|
| 307 |
+
quarter_bias =
|
| 308 |
+
std::move(CacheAlignedVector<BiasType>(bias_4, start_row, end_row));
|
| 309 |
+
}
|
| 310 |
+
// The part of the matrix that only depends on this thread for rhs inputs.
|
| 311 |
+
CsrBlockSparseMatrix<WeightType, RhsType> self_matrix;
|
| 312 |
+
CacheAlignedVector<BiasType> full_bias;
|
| 313 |
+
CacheAlignedVector<BiasType> quarter_bias;
|
| 314 |
+
// The part of the matrix that uses rhs inputs from other threads.
|
| 315 |
+
CsrBlockSparseMatrix<WeightType, RhsType> other_matrix;
|
| 316 |
+
};
|
| 317 |
+
CsrBlockSparseMatrix<WeightType, RhsType, DeltaType> sparse_matrix_;
|
| 318 |
+
CacheAlignedVector<BiasType> bias_;
|
| 319 |
+
CacheAlignedVector<BiasType> full_bias_;
|
| 320 |
+
// Output from the self_matrix that will be given to |other_matrix| as bias.
|
| 321 |
+
CacheAlignedVector<BiasType> mid_output_;
|
| 322 |
+
// One partitioned pair of matrices for each thread.
|
| 323 |
+
std::vector<PartLinearLayer> thread_layers_;
|
| 324 |
+
// Producer-consumer lock used to wait between computing |self_matrix| and
|
| 325 |
+
// |other_matrix| for the other threads to finish the *previous* matvec.
|
| 326 |
+
std::unique_ptr<ProducerConsumer> split_pc_;
|
| 327 |
+
int num_threads_ = 0;
|
| 328 |
+
};
|
| 329 |
+
|
| 330 |
+
template <typename WeightType, typename RhsType>
|
| 331 |
+
SparseLinearLayer<WeightType, RhsType> CreateRandomLayer(int rows, int cols,
|
| 332 |
+
float sparsity,
|
| 333 |
+
int block_height = 1,
|
| 334 |
+
int block_width = 1) {
|
| 335 |
+
typedef typename TypeOfProduct<WeightType, RhsType>::type BiasType;
|
| 336 |
+
CacheAlignedVector<BiasType> bias(rows);
|
| 337 |
+
bias.FillRandom();
|
| 338 |
+
|
| 339 |
+
auto masked_matrix = MaskedSparseMatrix<float>(rows, cols, sparsity,
|
| 340 |
+
block_height, block_width);
|
| 341 |
+
auto sparse_matrix = CsrBlockSparseMatrix<WeightType, RhsType>(masked_matrix);
|
| 342 |
+
|
| 343 |
+
return SparseLinearLayer<WeightType, RhsType>(std::move(sparse_matrix),
|
| 344 |
+
std::move(bias));
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
template <typename WeightType, typename RhsType>
|
| 348 |
+
SparseLinearLayer<WeightType, RhsType> CreateConstantLayer(
|
| 349 |
+
int rows, int cols, float sparsity, float constant = 1.f) {
|
| 350 |
+
typedef typename TypeOfProduct<WeightType, RhsType>::type BiasType;
|
| 351 |
+
CacheAlignedVector<BiasType> bias(rows);
|
| 352 |
+
bias.FillOnes();
|
| 353 |
+
|
| 354 |
+
MaskedSparseMatrix<float> masked_matrix(rows, cols, sparsity,
|
| 355 |
+
/*block_height=*/1, /*block_width=*/1,
|
| 356 |
+
constant, /*random=*/false);
|
| 357 |
+
CsrBlockSparseMatrix<WeightType, RhsType> sparse_matrix(masked_matrix);
|
| 358 |
+
|
| 359 |
+
return SparseLinearLayer<WeightType, RhsType>(std::move(sparse_matrix),
|
| 360 |
+
std::move(bias));
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
} // namespace csrblocksparse
|
| 364 |
+
|
| 365 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_
|
sparse_matmul/layers/sparse_linear_layer_test.cc
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright 2021 Google LLC
|
| 2 |
+
//
|
| 3 |
+
// Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
// you may not use this file except in compliance with the License.
|
| 5 |
+
// You may obtain a copy of the License at
|
| 6 |
+
//
|
| 7 |
+
// http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
//
|
| 9 |
+
// Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
// distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
// See the License for the specific language governing permissions and
|
| 13 |
+
// limitations under the License.
|
| 14 |
+
|
| 15 |
+
#include "sparse_matmul/layers/sparse_linear_layer.h"
|
| 16 |
+
|
| 17 |
+
#include "gmock/gmock.h"
|
| 18 |
+
#include "gtest/gtest.h"
|
| 19 |
+
#include "sparse_matmul/numerics/test_utils.h"
|
| 20 |
+
|
| 21 |
+
namespace csrblocksparse {
|
| 22 |
+
namespace {
|
| 23 |
+
|
| 24 |
+
constexpr int kBlockSize = 4;
|
| 25 |
+
constexpr int kSize = 256;
|
| 26 |
+
constexpr int kNumThreads = 4;
|
| 27 |
+
constexpr int kCols = 1;
|
| 28 |
+
|
| 29 |
+
void SlicedThreadBody(SpinBarrier* spin_barrier, int tid,
|
| 30 |
+
const FatCacheAlignedVector<float>& rhs,
|
| 31 |
+
SparseLinearLayer<float, float>* sparse_linear_layer,
|
| 32 |
+
FatCacheAlignedVector<float>* out, bool use_relu) {
|
| 33 |
+
sparse_linear_layer->MatVec(rhs, use_relu, tid, /*replicas=*/1,
|
| 34 |
+
/*output_stride=*/0, out);
|
| 35 |
+
spin_barrier->barrier();
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
// Tests that a Layer that has been SliceForThreads computes the same result as
|
| 39 |
+
// the original layer. This is a basic test that all the slicing didn't mess up
|
| 40 |
+
// any of the computations.
|
| 41 |
+
TEST(CsrBlockSparseMatrix, SliceForThreads) {
|
| 42 |
+
MaskedSparseMatrix<float> matrix(kSize, kSize, 0.95, kBlockSize, kBlockSize);
|
| 43 |
+
FatCacheAlignedVector<float> rhs(kSize, kCols);
|
| 44 |
+
CacheAlignedVector<float> bias(kSize);
|
| 45 |
+
FatCacheAlignedVector<float> out1(kSize, kCols);
|
| 46 |
+
|
| 47 |
+
bias.FillRandom();
|
| 48 |
+
rhs.FillRandom();
|
| 49 |
+
out1.FillZero();
|
| 50 |
+
FatCacheAlignedVector<float> out_reference = out1;
|
| 51 |
+
CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);
|
| 52 |
+
SparseLinearLayer<float, float> sparse_linear_layer(std::move(sparse_matrix),
|
| 53 |
+
std::move(bias));
|
| 54 |
+
sparse_linear_layer.PrepareForThreads(1);
|
| 55 |
+
sparse_linear_layer.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
|
| 56 |
+
/*output_stride=*/0, &out_reference);
|
| 57 |
+
std::vector<int> fake_split_points = {0, 48 / kBlockSize, 128 / kBlockSize,
|
| 58 |
+
208 / kBlockSize, kSize / kBlockSize};
|
| 59 |
+
sparse_linear_layer.PrepareForThreads(kNumThreads);
|
| 60 |
+
sparse_linear_layer.SliceForThreads(fake_split_points);
|
| 61 |
+
csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, SlicedThreadBody, rhs,
|
| 62 |
+
&sparse_linear_layer, &out1,
|
| 63 |
+
/*relu=*/true);
|
| 64 |
+
|
| 65 |
+
CheckResult(out_reference, out1, kCols);
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
void LayersThreadBody(SpinBarrier* spin_barrier, int tid,
|
| 69 |
+
const FatCacheAlignedVector<float>& rhs,
|
| 70 |
+
SparseLinearLayer<float, float>* sparse_linear_layer1,
|
| 71 |
+
SparseLinearLayer<float, float>* sparse_linear_layer2,
|
| 72 |
+
FatCacheAlignedVector<float>* out1,
|
| 73 |
+
FatCacheAlignedVector<float>* out2, bool use_relu) {
|
| 74 |
+
sparse_linear_layer1->MatVec(rhs, use_relu, tid, /*replicas=*/1,
|
| 75 |
+
/*output_stride=*/0, out1);
|
| 76 |
+
// NOTE no barrier here!
|
| 77 |
+
sparse_linear_layer2->MatVec(*out1, use_relu, tid, /*replicas=*/1,
|
| 78 |
+
/*output_stride=*/0, out2);
|
| 79 |
+
spin_barrier->barrier();
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
// Tests that a pair of layers computes the same result whether or not the
|
| 83 |
+
// second layer has been SliceForThreads. This is a more critical test that
|
| 84 |
+
// the replacement of barriers with producer-consumer locks works.
|
| 85 |
+
// Must be run with tsan to really test it properly.
|
| 86 |
+
TEST(CsrBlockSparseMatrix, SliceForThreadsLayers) {
|
| 87 |
+
MaskedSparseMatrix<float> matrix1(kSize, kSize, 0.95, kBlockSize, kBlockSize);
|
| 88 |
+
FatCacheAlignedVector<float> rhs(kSize, kCols);
|
| 89 |
+
CacheAlignedVector<float> bias1(kSize);
|
| 90 |
+
FatCacheAlignedVector<float> out1(kSize, kCols);
|
| 91 |
+
MaskedSparseMatrix<float> matrix2(kSize, kSize, 0.95, kBlockSize, kBlockSize);
|
| 92 |
+
CacheAlignedVector<float> bias2(kSize);
|
| 93 |
+
FatCacheAlignedVector<float> out2(kSize, kCols);
|
| 94 |
+
|
| 95 |
+
bias1.FillRandom();
|
| 96 |
+
rhs.FillRandom();
|
| 97 |
+
bias2.FillRandom();
|
| 98 |
+
out1.FillZero();
|
| 99 |
+
out2.FillZero();
|
| 100 |
+
FatCacheAlignedVector<float> out_reference = out2;
|
| 101 |
+
CsrBlockSparseMatrix<float, float> sparse_matrix1(matrix1);
|
| 102 |
+
SparseLinearLayer<float, float> layer1(std::move(sparse_matrix1),
|
| 103 |
+
std::move(bias1));
|
| 104 |
+
CsrBlockSparseMatrix<float, float> sparse_matrix2(matrix2);
|
| 105 |
+
SparseLinearLayer<float, float> layer2(std::move(sparse_matrix2),
|
| 106 |
+
std::move(bias2));
|
| 107 |
+
layer1.PrepareForThreads(1);
|
| 108 |
+
layer2.PrepareForThreads(1);
|
| 109 |
+
layer1.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
|
| 110 |
+
/*output_stride=*/0, &out1);
|
| 111 |
+
layer2.MatVec(out1, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
|
| 112 |
+
/*output_stride=*/0, &out_reference);
|
| 113 |
+
layer1.PrepareForThreads(kNumThreads);
|
| 114 |
+
layer2.PrepareForThreads(kNumThreads);
|
| 115 |
+
layer2.SliceForThreads(layer1.split_points());
|
| 116 |
+
csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, LayersThreadBody, rhs,
|
| 117 |
+
&layer1, &layer2, &out1, &out2,
|
| 118 |
+
/*relu=*/true);
|
| 119 |
+
|
| 120 |
+
CheckResult(out_reference, out2, kCols);
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
// Tests that a Layer that has been DoubleBlockHeight()-ed computes the same
|
| 124 |
+
// result as original layer. (Float compute type).
|
| 125 |
+
TEST(CsrBlockSparseMatrix, Float8x4) {
|
| 126 |
+
using ComputeType = float;
|
| 127 |
+
using RhsType = float;
|
| 128 |
+
using BiasType = float;
|
| 129 |
+
MaskedSparseMatrix<float> matrix(kSize, kSize, 0.95, kBlockSize, kBlockSize);
|
| 130 |
+
matrix.CastWeights<ComputeType>();
|
| 131 |
+
FatCacheAlignedVector<RhsType> rhs(kSize, kCols);
|
| 132 |
+
CacheAlignedVector<BiasType> bias(kSize);
|
| 133 |
+
FatCacheAlignedVector<BiasType> out1(kSize, kCols);
|
| 134 |
+
|
| 135 |
+
bias.FillRandom();
|
| 136 |
+
rhs.FillRandom();
|
| 137 |
+
out1.FillZero();
|
| 138 |
+
FatCacheAlignedVector<BiasType> out_reference = out1;
|
| 139 |
+
CsrBlockSparseMatrix<ComputeType, RhsType> sparse_matrix(matrix);
|
| 140 |
+
SparseLinearLayer<ComputeType, RhsType> sparse_linear_layer(
|
| 141 |
+
std::move(sparse_matrix), std::move(bias));
|
| 142 |
+
sparse_linear_layer.PrepareForThreads(1);
|
| 143 |
+
sparse_linear_layer.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
|
| 144 |
+
/*output_stride=*/0, &out_reference);
|
| 145 |
+
sparse_linear_layer.DoubleBlockHeight();
|
| 146 |
+
sparse_linear_layer.PrepareForThreads(1);
|
| 147 |
+
sparse_linear_layer.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
|
| 148 |
+
/*output_stride=*/0, &out1);
|
| 149 |
+
CheckResult(out_reference, out1, kCols);
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
// Tests that a Layer that has been DoubleBlockHeight()-ed computes the same
|
| 153 |
+
// result as original layer. (Fixed16 compute type).
|
| 154 |
+
TEST(CsrBlockSparseMatrix, Fixed8x4) {
|
| 155 |
+
using ComputeType = csrblocksparse::fixed16<4>;
|
| 156 |
+
using RhsType = csrblocksparse::fixed16<4>;
|
| 157 |
+
using BiasType = typename TypeOfProduct<ComputeType, RhsType>::type;
|
| 158 |
+
MaskedSparseMatrix<float> matrix(kSize, kSize, 0.95, kBlockSize, kBlockSize);
|
| 159 |
+
matrix.CastWeights<ComputeType>();
|
| 160 |
+
FatCacheAlignedVector<RhsType> rhs(kSize, kCols);
|
| 161 |
+
CacheAlignedVector<BiasType> bias(kSize);
|
| 162 |
+
FatCacheAlignedVector<BiasType> out1(kSize, kCols);
|
| 163 |
+
|
| 164 |
+
bias.FillRandom();
|
| 165 |
+
rhs.FillRandom();
|
| 166 |
+
out1.FillZero();
|
| 167 |
+
FatCacheAlignedVector<BiasType> out_reference = out1;
|
| 168 |
+
CsrBlockSparseMatrix<ComputeType, RhsType> sparse_matrix(matrix);
|
| 169 |
+
SparseLinearLayer<ComputeType, RhsType> sparse_linear_layer(
|
| 170 |
+
std::move(sparse_matrix), std::move(bias));
|
| 171 |
+
sparse_linear_layer.PrepareForThreads(1);
|
| 172 |
+
sparse_linear_layer.MatVec(rhs, /*relu=*/false, /*tid=*/0, /*replicas=*/1,
|
| 173 |
+
/*output_stride=*/0, &out_reference);
|
| 174 |
+
sparse_linear_layer.DoubleBlockHeight();
|
| 175 |
+
sparse_linear_layer.PrepareForThreads(1);
|
| 176 |
+
sparse_linear_layer.MatVec(rhs, /*relu=*/false, /*tid=*/0, /*replicas=*/1,
|
| 177 |
+
/*output_stride=*/0, &out1);
|
| 178 |
+
CheckResult(out_reference, out1, kCols);
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
TEST(SparseLinearLayerTest, PrintCompiles) {
|
| 182 |
+
SparseLinearLayer<float, float> sparse_linear_layer;
|
| 183 |
+
sparse_linear_layer.Print();
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
} // namespace
|
| 187 |
+
} // namespace csrblocksparse
|
sparse_matmul/layers/status_macros.h
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright 2021 Google LLC
|
| 2 |
+
//
|
| 3 |
+
// Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
// you may not use this file except in compliance with the License.
|
| 5 |
+
// You may obtain a copy of the License at
|
| 6 |
+
//
|
| 7 |
+
// http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
//
|
| 9 |
+
// Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
// distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
// See the License for the specific language governing permissions and
|
| 13 |
+
// limitations under the License.
|
| 14 |
+
|
| 15 |
+
#ifndef THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_STATUS_MACROS_H_
|
| 16 |
+
#define THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_STATUS_MACROS_H_
|
| 17 |
+
|
| 18 |
+
#include "absl/status/status.h"
|
| 19 |
+
#include "absl/status/statusor.h"
|
| 20 |
+
|
| 21 |
+
#define SPARSE_MATMUL_RETURN_IF_ERROR(expr) \
|
| 22 |
+
do { \
|
| 23 |
+
const absl::Status _status = (expr); \
|
| 24 |
+
if (!_status.ok()) return _status; \
|
| 25 |
+
} while (0)
|
| 26 |
+
template <typename T>
|
| 27 |
+
absl::Status DoAssignOrReturn(T& lhs, absl::StatusOr<T> result) {
|
| 28 |
+
if (result.ok()) {
|
| 29 |
+
lhs = result.value();
|
| 30 |
+
}
|
| 31 |
+
return result.status();
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
#endif // THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_STATUS_MACROS_H_
|
sparse_matmul/layers/testdata/768_512_95_4x4_QRhat_weights.raw.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:50f861af29b1f767830d74ef83874944b18d80157b6b0256fdc4c14fa79ec936
|
| 3 |
+
size 20852
|
sparse_matmul/layers/testdata/768_512_95_4x4_What_weights.raw.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a2d534bde2caf6e59990a46b4b1907088b8144c53d62d97de7e2b4bdc956da68
|
| 3 |
+
size 5133
|
sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_bias.raw.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:11399f9d0e8f8dfbef6eb37e0c096f858658bc650f728a08f3135ccca44f0a5a
|
| 3 |
+
size 1062
|
sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_mask.raw.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d3d971e067a6df985d68beac26bcf4e9a6cc13ff328599e84d50a0fc9a7c103b
|
| 3 |
+
size 2382
|
sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_weights.raw.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d1376ef7a360699dae24a49f40a254990d4a70b844dadcdbe9dcbf1a306999a8
|
| 3 |
+
size 55829
|
sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_bias.raw.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ffcc8ccf086fccfacc928877aa29ef03ce51cce0f0b7d2aacf81782b7b527089
|
| 3 |
+
size 2003
|
sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_mask.raw.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7a16f98ba6f09031ea9fefb79fdc9ba90e44f0046ab70dab014ac971ca7f7186
|
| 3 |
+
size 4684
|