|  | import torch | 
					
						
						|  | import megablocks | 
					
						
						|  |  | 
					
						
						|  | def test_import(): | 
					
						
						|  | """Simple test to check if the module can be imported.""" | 
					
						
						|  | print("megablocks_moe module imported successfully.") | 
					
						
						|  | print("Available functions:", dir(megablocks)) | 
					
						
						|  |  | 
					
						
						|  | expected_functions = [ | 
					
						
						|  | "Arguments", "MLP", "MoE", "ParallelDroplessMLP", "ParallelMLP", | 
					
						
						|  | "SparseGLU", "SparseMLP", "argsort", | 
					
						
						|  | "backend", "cumsum", "dMoE", "exclusive_cumsum", | 
					
						
						|  | "get_load_balancing_loss", "grouped_gemm_util", "histogram", | 
					
						
						|  | "inclusive_cumsum", "indices", "layers", "ops", "replicate_backward", | 
					
						
						|  | "replicate_forward", "sort", "torch" | 
					
						
						|  | ] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for func in expected_functions: | 
					
						
						|  | assert func in dir(megablocks), f"Missing function: {func}" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def test_exclusive_cumsum(): | 
					
						
						|  | """Test exclusive cumulative sum.""" | 
					
						
						|  | x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda() | 
					
						
						|  | out = torch.empty_like(x) | 
					
						
						|  | megablocks.exclusive_cumsum(x, 0, out) | 
					
						
						|  | expected = torch.tensor([0, 1, 3, 6], dtype=torch.float32).cuda() | 
					
						
						|  | assert torch.equal(out, expected), f"Expected {expected}, got {out}" | 
					
						
						|  | print("cumsum output:", out) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def test_inclusive_cumsum(): | 
					
						
						|  | """Test inclusive cumulative sum.""" | 
					
						
						|  | x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda() | 
					
						
						|  | out = torch.empty_like(x) | 
					
						
						|  | megablocks.inclusive_cumsum(x, dim=0, out=out) | 
					
						
						|  | expected = torch.tensor([1, 3, 6, 10], dtype=torch.float32).cuda() | 
					
						
						|  | assert torch.equal(out, expected), f"Expected {expected}, got {out}" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def test_histogram(): | 
					
						
						|  | """Test histogram operation.""" | 
					
						
						|  | x = torch.tensor([0, 1, 1, 2, 2, 2], dtype=torch.int16).cuda() | 
					
						
						|  | num_bins = 3 | 
					
						
						|  | hist = megablocks.histogram(x, num_bins) | 
					
						
						|  | expected_hist = torch.tensor([1, 2, 3], dtype=torch.int32).cuda() | 
					
						
						|  | assert torch.equal(hist, expected_hist), f"Expected {expected_hist}, got {hist}" | 
					
						
						|  |  |