雕虫小技:用 C Macro 进行『模板编程』

最近和一个 C++ 程序员聊的时候被对方说写 C 不能用模板编程,会浪费很多时间。我想对方作为一个写了十多年 C++ 的老程序员肯定知道 C 能用 macro 玩出许多花样,只是不想让我这种菜鸡被这些禁忌领域的东西荼毒心灵。于是我自己找了一下资料,搞了一个满足我需求的方法出来,记录如下。

什么时候写 C 需要用模板编程?对于我来说,主要考虑如下两种情况:

  1. 有时候我们需要为同一个函数创建支持不同数据类型(最经典的场景:double / float)的实现。
  2. 有时候我们可以知道某个函数的某些参数只有为数不多的取值,而这些参数的取值直接决定了函数中的分支选择 and/or 循环长度 and/or 数组下标 and/or 其他影响性能的部分。

第一种情况很容易用 macro 解决,大部分时候简单把变量类型改成一个 macro 就可以了,只有涉及针对类型的外部函数(比如数学函数)时比较麻烦。第二种情况的解决方案相对复杂一点点。先考虑手动解决第二种情况的话我们要怎么做。一般来说,我们会给每一个固定参数值创建一个函数的特例。这个特例的函数名一般要体现这个固定的参数值,这样在随后调用的时候可以辨别。在这个特例的函数体里,那个被固定下来的参数变成了一个常量,这样编译器可以在编译期尽可能多地计算一些信息来优化生成的代码。下面的代码是我用 macro 写的针对第二种情况的解决方案。

这个样例代码里用到的技巧是 macro concatenation:在 macro 定义中,## 将连接起它两边空格以外的两个 macro 或者字符串和一个 macro。代码第 8、9 行用来生成带有固定参数值的函数名,用 ## 拼接起了函数名固定的部分和函数名里的固定参数值。第 10-12 行是函数的其他输入输出参数。第 13-20 行是模板函数的函数体,函数体内的 PARAM1PARAM2 会被替换成常数。需要注意的是,如果模板函数体内需要用 pragma,不能用 #pragma 的形式,需要用第 16 行的 _Pragma。第 22-24 行用不同的参数组合创建了三个模板函数的特列,并在后面被调用了。

需要指出,使用这样的 macro template programming 可能有如下弊端:

  1. 部分 profiler 无法正确统计或者显示代码块或者每一行的耗时。
  2. 部分编辑器无法对 macro 内的代码进行语法高亮显示(Sublime Text 3 无障碍,VS Code 部分关键词高亮,Notepad++、Markdown 代码块全灰,……),所有编辑器都无法找到函数声明和进行函数名/参数补全。

另外再吐槽一下 GCC。下面这个代码在 Compiler Explorer 里用 GCC 8.3 编译的话,GCC 8.3 还傻乎乎地用 128 位的 xmm 寄存器来和 vfmadd213sd 指令搭配(证明的确是在用 AVX2 指令集),GCC 9.1 开始才正确地使用了 256 位的 ymm 寄存器……

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
// Compile: gcc -O3 -std=gnu99 -march=core-avx2 -fopenmp -Wall marco_template.c -o marco_template.exe
// Reference: 1. https://gcc.gnu.org/onlinedocs/cpp/Macros.html
// 2. https://stackoverflow.com/questions/1253934/c-pre-processor-defining-for-generated-function-names
#include <stdio.h>
#include <string.h>
#define VEC_FUN_TEMPLATE(PARAM1, PARAM2) \
void vec_func_ ## PARAM1 ## _ ## PARAM2 ( \
const double *x, const double *y, \
double *__restrict__ z \
) \
{ \
for (int i = 0; i < PARAM1; i++) \
{ \
_Pragma("omp simd") \
for (int j = 0; j < PARAM2; j++) \
z[j] += x[i] * y[j]; \
} \
}
VEC_FUN_TEMPLATE(2, 4) // This gives you vec_func_2_4(x, y, z)
VEC_FUN_TEMPLATE(2, 8) // This gives you vec_func_2_8(x, y, z)
VEC_FUN_TEMPLATE(4, 8) // This gives you vec_func_4_8(x, y, z)
int main()
{
double x[8] = {0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0};
double y[8] = {1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0};
double z[8];
memset(z, 0, sizeof(double) * 8);
vec_func_2_4(x, y, z);
for (int i = 0; i < 8; i++) printf("%.2lf ", z[i]);
printf("\n");
memset(z, 0, sizeof(double) * 8);
vec_func_2_8(x, y, z);
for (int i = 0; i < 8; i++) printf("%.2lf ", z[i]);
printf("\n");
memset(z, 0, sizeof(double) * 8);
vec_func_4_8(x, y, z);
for (int i = 0; i < 8; i++) printf("%.2lf ", z[i]);
printf("\n");
return 0;
}