#!/usr/bin/python3 # Copyright (C) 2023 Free Software Foundation, Inc. # This file is part of the GNU C Library. # # The GNU C Library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public # License as published by the Free Software Foundation; either # version 2.1 of the License, or (at your option) any later version. # # The GNU C Library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public # License along with the GNU C Library; if not, see # . import sys TEMPLATE = """ #include #include #define STRIDE {stride} #define CALL_BENCH_FUNC(v, i) (__extension__ ({{ \\ {rtype} mx0 = {fname}(vld1q_f{prec_short} (variants[v].in[i].arg0)); \\ mx0; }})) struct args {{ {stype} arg0[STRIDE]; double timing; }}; struct _variants {{ const char *name; int count; const struct args *in; }}; static const struct args in0[{rowcount}] = {{ {in_data} }}; static const struct _variants variants[1] = {{ {{"", {rowcount}, in0}}, }}; #define NUM_VARIANTS 1 #define NUM_SAMPLES(i) (variants[i].count) #define VARIANT(i) (variants[i].name) static {rtype} volatile ret; #define BENCH_FUNC(i, j) ({{ ret = CALL_BENCH_FUNC(i, j); }}) #define FUNCNAME "{fname}" #include """ def main(name): _, prec, _, func = name.split("-") scalar_to_advsimd_type = {"double": "float64x2_t", "float": "float32x4_t"} stride = {"double": 2, "float": 4}[prec] rtype = scalar_to_advsimd_type[prec] atype = scalar_to_advsimd_type[prec] fname = f"_ZGVnN{stride}v_{func}{'f' if prec == 'float' else ''}" prec_short = {"double": 64, "float": 32}[prec] with open(f"../benchtests/libmvec/{func}-inputs") as f: in_vals = [l.strip() for l in f.readlines() if l and not l.startswith("#")] in_vals = [in_vals[i:i+stride] for i in range(0, len(in_vals), stride)] rowcount= len(in_vals) in_data = ",\n".join("{{" + ", ".join(row) + "}, 0}" for row in in_vals) print(TEMPLATE.format(stride=stride, rtype=rtype, atype=atype, fname=fname, prec_short=prec_short, in_data=in_data, rowcount=rowcount, stype=prec)) if __name__ == "__main__": main(sys.argv[1])