// Copyright (c) Meta Platforms, Inc. and affiliates. // All rights reserved. // // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. #ifndef cudadispatch_h_ #define cudadispatch_h_ #include #include #include template struct get_base { typedef T type; }; template struct get_base::value>::type> { typedef std::shared_ptr type; }; template struct is_shared_ptr : std::false_type {}; template struct is_shared_ptr> : std::true_type {}; template auto convert_shptr_impl2(std::shared_ptr t) { return *static_cast(t.get()); } template auto convert_shptr_impl(T&& t, std::false_type) { return convert_shptr_impl2(t); } template auto convert_shptr_impl(T&& t, std::true_type) { return std::forward(t); } template auto convert_shptr(T&& t) { return convert_shptr_impl(std::forward(t), std::is_same{}); } template struct cudacall { struct functbase { virtual ~functbase() {} virtual void call(dim3, dim3, cudaStream_t, ArgsIn...) const = 0; }; template struct funct : public functbase { std::function fn; funct(void(*fn_)(ArgsOut...)) : fn(fn_) { } void call(dim3 gridsize, dim3 blocksize, cudaStream_t stream, ArgsIn... args) const { void (*const*kfunc)(ArgsOut...) = fn.template target(); (*kfunc)<<>>( std::forward(convert_shptr(std::forward(args)))...); } }; std::shared_ptr fn; template cudacall(void(*fn_)(ArgsOut...)) : fn(std::make_shared>(fn_)) { } template void call(dim3 gridsize, dim3 blocksize, cudaStream_t stream, ArgsTmp&&... args) const { fn->call(gridsize, blocksize, stream, std::forward(args)...); } }; template struct binder { F f; T t; template auto operator()(Args&&... args) const -> decltype(f(t, std::forward(args)...)) { return f(t, std::forward(args)...); } }; template binder::type , typename std::decay::type> BindFirst(F&& f, T&& t) { return { std::forward(f), std::forward(t) }; } template auto make_cudacall_(void(*fn)(ArgsOut...)) { return BindFirst( std::mem_fn(&cudacall::type...>::template call::type...>), cudacall::type...>(fn)); } template std::function::type...)> make_cudacall(void(*fn)(ArgsOut...)) { return std::function::type...)>(make_cudacall_(fn)); } #endif