Description:Code optimization

Feature or Bugfix:Bugfix
Binary Source:No

Signed-off-by: lwx1281857 <linnanmu@h-partners.com>
This commit is contained in:
lwx1281857 2024-07-11 19:51:10 +08:00
parent 9f00689092
commit 43392e239e
7 changed files with 157 additions and 143 deletions

View File

@ -91,12 +91,12 @@ def write_map_to_code(code_name, data_dict):
f.write('typedef struct Node_ {' + os.linesep)
f.write(' const char *name;' + os.linesep)
f.write(' const char *value;' + os.linesep)
f.write('} Node;' + os.linesep + os.linesep)
f.write('} Node;' + os.linesep + os.linesep)
f.write('#define PARAM_MAP(name, value) {(const char *)#name, (const char *)#value},')
f.write(os.linesep + os.linesep)
f.write(os.linesep + os.linesep)
# write data
f.write('static Node g_paramDefCfgNodes[] = {' + os.linesep)
for name, value in data_dict.items():
for name, value in data_dict.items():
if (value.startswith("\"")):
tmp_str = " PARAM_MAP({0}, {1})".format(name, value)
f.write(tmp_str + os.linesep)
@ -120,7 +120,7 @@ def write_map_to_code(code_name, data_dict):
def add_to_code_dict(code_dict, cfg_dict, high=True):
for name, value in cfg_dict.items():
for name, value in cfg_dict.items():
# check if name exit
has_key = name in code_dict
if has_key and high:

View File

@ -23,6 +23,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir,
os.pardir, os.pardir, os.pardir, os.pardir, "build"))
from scripts.util import build_utils # noqa: E402
def parse_args(args):
args = build_utils.expand_file_args(args)
@ -35,6 +36,7 @@ def parse_args(args):
options, _ = parser.parse_args(args)
return options
def parse_params(line, contents):
line = line.strip()
pos = line.find('=')
@ -46,11 +48,13 @@ def parse_params(line, contents):
value = value.strip()
contents[name] = value
def parse_extra_params(extras, contents):
for extra in extras:
extra = extra.strip()
parse_params(extra, contents)
def fix_para_file(options):
contents = {}
@ -73,6 +77,7 @@ def fix_para_file(options):
for key in contents:
f.write("".join([key, "=", contents[key], '\n']))
def main(args):
options = parse_args(args)

View File

@ -106,72 +106,80 @@ def append_group_files(target_f, options):
for item in source_dict:
target_f.write(f"{item}:{':'.join(source_dict[item])}\n")
def handle_passwd_info(passwdInfo, limits):
isPassed = True
name = passwdInfo[0].strip()
gid = int(passwdInfo[3], 10)
uid = int(passwdInfo[2], 10)
def handle_passwd_info(passwd_info, limits):
is_passed = True
name = passwd_info[0].strip()
gid = int(passwd_info[3], 10)
uid = int(passwd_info[2], 10)
if gid >= int(limits[0]) and gid <= int(limits[1]):
pass
else:
isPassed = False
is_passed = False
log_str = "error: name={} gid={} is not in range {}".format(name, gid, limits)
print(log_str)
if uid >= int(limits[0]) and uid <= int(limits[1]):
pass
else:
isPassed = False
is_passed = False
log_str = "error: name={} uid={} is not in range {}".format(name, gid, limits)
print(log_str)
return isPassed
return is_passed
def check_passwd_file(file_name, limits):
isPassed = True
is_passed = True
with open(file_name, encoding='utf-8') as fp:
line = fp.readline()
while line :
if line.startswith("#") or len(line) < 3:
line = fp.readline()
continue
passwdInfo = line.strip("\n").split(":")
if len (passwdInfo) < 4:
passwd_info = line.strip("\n").split(":")
if len(passwd_info) < 4:
line = fp.readline()
continue
if not handle_passwd_info(passwdInfo, limits):
isPassed = False
if not handle_passwd_info(passwd_info, limits):
is_passed = False
line = fp.readline()
return isPassed
return is_passed
def load_file(file_name, limit):
if not os.path.exists(file_name):
print("error: %s is not exit", file_name)
return False
isPassed = True
is_passed = True
limits = limit.split("-")
try:
isPassed = check_passwd_file(file_name, limits)
is_passed = check_passwd_file(file_name, limits)
except:
raise Exception("Exception in reading passwd, file name:", file_name)
return isPassed
return is_passed
def append_passwd_files(target_f, options):
# Read source file
file_list = options.source_file.split(":")
range_list = options.input_ranges.split(":")
for i in range(len(file_list)):
if not load_file(file_list[i], range_list[i]):
# check gid/uid Exception log: raise Exception("Exception, check passwd file error, ", file_list[i])
print("error: heck passwd file error, file path: ", file_list[i])
for i, file in enumerate(file_list):
if i >= len(range_list):
print("error: %s is error", file)
return
if not load_file(file, range_list[i]):
# check gid/uid Exception log: raise Exception("Exception, check passwd file error, ", file)
print("error: heck passwd file error, file path: ", file)
pass
try:
with open(file_list[i], 'r') as source_f:
with open(file, 'r') as source_f:
source_contents = source_f.read()
target_f.write(source_contents)
except:
raise Exception("Exception in appending passwd, file name:", file_list[i])
raise Exception("Exception in appending passwd, file name:", file)
def main(args):
sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir,

View File

@ -27,7 +27,7 @@ supported_parse_item = ['labelName', 'priority', 'allowList', 'blockList', 'prio
'allowListWithArgs', 'headFiles', 'selfDefineSyscall', 'returnValue', \
'mode', 'privilegedProcessName', 'allowBlockList']
supported_architecture = ['arm', 'arm64','riscv64']
supported_architecture = ['arm', 'arm64', 'riscv64']
BPF_JGE = 'BPF_JUMP(BPF_JMP|BPF_JGE|BPF_K, {}, {}, {}),'
BPF_JGT = 'BPF_JUMP(BPF_JMP|BPF_JGT|BPF_K, {}, {}, {}),'
@ -343,6 +343,89 @@ class GenBpfPolicy:
'&' : self.gen_bpf_set,
}
@staticmethod
def gen_bpf_eq32(const_str, jt, jf):
bpf_policy = []
bpf_policy.append(BPF_JEQ.format(const_str + ' & 0xffffffff', jt, jf))
return bpf_policy
@staticmethod
def gen_bpf_eq64(const_str, jt, jf):
bpf_policy = []
bpf_policy.append(BPF_JEQ.format('((unsigned long)' + const_str + ') >> 32', 0, jf + 2))
bpf_policy.append(BPF_LOAD_MEM.format(0))
bpf_policy.append(BPF_JEQ.format(const_str + ' & 0xffffffff', jt, jf))
return bpf_policy
@staticmethod
def gen_bpf_gt32(const_str, jt, jf):
bpf_policy = []
bpf_policy.append(BPF_JGT.format(const_str + ' & 0xffffffff', jt, jf))
return bpf_policy
@staticmethod
def gen_bpf_gt64(const_str, jt, jf):
bpf_policy = []
number, digit_flag = str_convert_to_int(const_str)
hight = int(number / (2**32))
low = number & 0xffffffff
if digit_flag and hight == 0:
bpf_policy.append(BPF_JGT.format('((unsigned long)' + const_str + ') >> 32', jt + 2, 0))
else:
bpf_policy.append(BPF_JGT.format('((unsigned long)' + const_str + ') >> 32', jt + 3, 0))
bpf_policy.append(BPF_JEQ.format('((unsigned long)' + const_str + ') >> 32', 0, jf + 2))
bpf_policy.append(BPF_LOAD_MEM.format(0))
bpf_policy.append(BPF_JGT.format(const_str + ' & 0xffffffff', jt, jf))
return bpf_policy
@staticmethod
def gen_bpf_ge32(const_str, jt, jf):
bpf_policy = []
bpf_policy.append(BPF_JGE.format(const_str + ' & 0xffffffff', jt, jf))
return bpf_policy
@staticmethod
def gen_bpf_ge64(const_str, jt, jf):
bpf_policy = []
number, digit_flag = str_convert_to_int(const_str)
hight = int(number / (2**32))
low = number & 0xffffffff
if digit_flag and hight == 0:
bpf_policy.append(BPF_JGT.format('((unsigned long)' + const_str + ') >> 32', jt + 2, 0))
else:
bpf_policy.append(BPF_JGT.format('((unsigned long)' + const_str + ') >> 32', jt + 3, 0))
bpf_policy.append(BPF_JEQ.format('((unsigned long)' + const_str + ') >> 32', 0, jf + 2))
bpf_policy.append(BPF_LOAD_MEM.format(0))
bpf_policy.append(BPF_JGE.format(const_str + ' & 0xffffffff', jt, jf))
return bpf_policy
@staticmethod
def gen_bpf_set32(const_str, jt, jf):
bpf_policy = []
bpf_policy.append(BPF_JSET.format(const_str + ' & 0xffffffff', jt, jf))
return bpf_policy
@staticmethod
def gen_bpf_set64(const_str, jt, jf):
bpf_policy = []
bpf_policy.append(BPF_JSET.format('((unsigned long)' + const_str + ') >> 32', jt + 2, 0))
bpf_policy.append(BPF_LOAD_MEM.format(0))
bpf_policy.append(BPF_JSET.format(const_str + ' & 0xffffffff', jt, jf))
return bpf_policy
@staticmethod
def gen_bpf_valid_syscall_nr(syscall_nr, cur_size):
bpf_policy = []
bpf_policy.append(BPF_LOAD.format(0))
bpf_policy.append(BPF_JEQ.format(syscall_nr, 0, cur_size))
return bpf_policy
def update_arch(self, arch):
self.arch = arch
self.syscall_nr_range = []
@ -377,20 +460,6 @@ class GenBpfPolicy:
self.return_value = return_value
@staticmethod
def gen_bpf_eq32(const_str, jt, jf):
bpf_policy = []
bpf_policy.append(BPF_JEQ.format(const_str + ' & 0xffffffff', jt, jf))
return bpf_policy
@staticmethod
def gen_bpf_eq64(const_str, jt, jf):
bpf_policy = []
bpf_policy.append(BPF_JEQ.format('((unsigned long)' + const_str + ') >> 32', 0, jf + 2))
bpf_policy.append(BPF_LOAD_MEM.format(0))
bpf_policy.append(BPF_JEQ.format(const_str + ' & 0xffffffff', jt, jf))
return bpf_policy
def gen_bpf_eq(self, const_str, jt, jf):
if self.arch == 'arm':
return self.gen_bpf_eq32(const_str, jt, jf)
@ -401,31 +470,6 @@ class GenBpfPolicy:
def gen_bpf_ne(self, const_str, jt, jf):
return self.gen_bpf_eq(const_str, jf, jt)
@staticmethod
def gen_bpf_gt32(const_str, jt, jf):
bpf_policy = []
bpf_policy.append(BPF_JGT.format(const_str + ' & 0xffffffff', jt, jf))
return bpf_policy
@staticmethod
def gen_bpf_gt64(const_str, jt, jf):
bpf_policy = []
number, digit_flag = str_convert_to_int(const_str)
hight = int(number / (2**32))
low = number & 0xffffffff
if digit_flag and hight == 0:
bpf_policy.append(BPF_JGT.format('((unsigned long)' + const_str + ') >> 32', jt + 2, 0))
else:
bpf_policy.append(BPF_JGT.format('((unsigned long)' + const_str + ') >> 32', jt + 3, 0))
bpf_policy.append(BPF_JEQ.format('((unsigned long)' + const_str + ') >> 32', 0, jf + 2))
bpf_policy.append(BPF_LOAD_MEM.format(0))
bpf_policy.append(BPF_JGT.format(const_str + ' & 0xffffffff', jt, jf))
return bpf_policy
def gen_bpf_gt(self, const_str, jt, jf):
if self.arch == 'arm':
return self.gen_bpf_gt32(const_str, jt, jf)
@ -436,29 +480,6 @@ class GenBpfPolicy:
def gen_bpf_le(self, const_str, jt, jf):
return self.gen_bpf_gt(const_str, jf, jt)
@staticmethod
def gen_bpf_ge32(const_str, jt, jf):
bpf_policy = []
bpf_policy.append(BPF_JGE.format(const_str + ' & 0xffffffff', jt, jf))
return bpf_policy
@staticmethod
def gen_bpf_ge64(const_str, jt, jf):
bpf_policy = []
number, digit_flag = str_convert_to_int(const_str)
hight = int(number / (2**32))
low = number & 0xffffffff
if digit_flag and hight == 0:
bpf_policy.append(BPF_JGT.format('((unsigned long)' + const_str + ') >> 32', jt + 2, 0))
else:
bpf_policy.append(BPF_JGT.format('((unsigned long)' + const_str + ') >> 32', jt + 3, 0))
bpf_policy.append(BPF_JEQ.format('((unsigned long)' + const_str + ') >> 32', 0, jf + 2))
bpf_policy.append(BPF_LOAD_MEM.format(0))
bpf_policy.append(BPF_JGE.format(const_str + ' & 0xffffffff', jt, jf))
return bpf_policy
def gen_bpf_ge(self, const_str, jt, jf):
if self.arch == 'arm':
return self.gen_bpf_ge32(const_str, jt, jf)
@ -469,20 +490,6 @@ class GenBpfPolicy:
def gen_bpf_lt(self, const_str, jt, jf):
return self.gen_bpf_ge(const_str, jf, jt)
@staticmethod
def gen_bpf_set32(const_str, jt, jf):
bpf_policy = []
bpf_policy.append(BPF_JSET.format(const_str + ' & 0xffffffff', jt, jf))
return bpf_policy
@staticmethod
def gen_bpf_set64(const_str, jt, jf):
bpf_policy = []
bpf_policy.append(BPF_JSET.format('((unsigned long)' + const_str + ') >> 32', jt + 2, 0))
bpf_policy.append(BPF_LOAD_MEM.format(0))
bpf_policy.append(BPF_JSET.format(const_str + ' & 0xffffffff', jt, jf))
return bpf_policy
def gen_bpf_set(self, const_str, jt, jf):
if self.arch == 'arm':
return self.gen_bpf_set32(const_str, jt, jf)
@ -490,13 +497,6 @@ class GenBpfPolicy:
return self.gen_bpf_set64(const_str, jt, jf)
return []
@staticmethod
def gen_bpf_valid_syscall_nr(syscall_nr, cur_size):
bpf_policy = []
bpf_policy.append(BPF_LOAD.format(0))
bpf_policy.append(BPF_JEQ.format(syscall_nr, 0, cur_size))
return bpf_policy
def gen_range_list(self, syscall_nr_list):
if len(syscall_nr_list) == 0:
return
@ -542,7 +542,7 @@ class GenBpfPolicy:
def nr_range_to_bpf_policy(self, cur_syscall_nr_range):
self.gen_policy_syscall_nr_list(cur_syscall_nr_range)
syscall_list_len = len(self.syscall_nr_policy_list)
syscall_list_len = len(self.syscall_nr_policy_list)
if syscall_list_len == 0:
return
@ -751,7 +751,7 @@ class GenBpfPolicy:
def parse_args(self, function_name, line, skip):
bpf_policy = []
group_info = line.split('else')
group_info = line.split('else')
else_part = group_info[-1]
group = group_info[0].split('elif')
for sub_group in group:
@ -763,7 +763,7 @@ class GenBpfPolicy:
bpf_policy.append(BPF_RET_VALUE.format(ret_str_to_bpf.get(self.return_value)))
syscall_nr = self.function_name_nr_table_dict.get(self.arch).get(function_name)
#load syscall nr
bpf_policy = self.gen_bpf_valid_syscall_nr(syscall_nr, len(bpf_policy) - skip) + bpf_policy
bpf_policy = self.gen_bpf_valid_syscall_nr(syscall_nr, len(bpf_policy) - skip) + bpf_policy
return bpf_policy
def gen_bpf_policy_with_args(self, allow_list_with_args, mode, return_value):
@ -977,7 +977,7 @@ class SeccompPolicyParser:
extra_header = set()
for arch in self.arches:
extra_header |= self.seccomp_policy_param.get(arch).head_files
extra_header_list = ['#include ' + i for i in sorted(list(extra_header))]
extra_header_list = ['#include ' + i for i in sorted(list(extra_header))]
filter_name = 'g_' + args.filter_name + 'SeccompFilter'
array_name = textwrap.dedent('''
@ -1060,7 +1060,7 @@ def main():
parser.add_argument('--dst-file',
help='The output path for the policy files')
parser.add_argument('--filter-name', type=str,
parser.add_argument('--filter-name', type=str,
help='Name of seccomp bpf array generated by this script')
parser.add_argument('--target-cpu', type=str,

View File

@ -57,7 +57,7 @@ def get_item_content(name_nr_table, arch_nr_table):
for func_name in syscall_name_dict.get('arm64'):
if func_name in syscall_name_dict.get('arm'):
content = '{}{};all\n'.format(content, func_name)
content = '{}{};all\n'.format(content, func_name)
syscall_name_dict.get('arm').remove(func_name)
else:
content = '{}{};arm64\n'.format(content, func_name)

View File

@ -31,7 +31,7 @@ class LibcFuncUnit:
self.use_function = set()
self.arch = arch
def merge_nr(self, nr):
def merge_nr(self, nr):
self.nr |= nr
def update_func_name(self, func_name):
@ -114,11 +114,11 @@ def get_direct_use_syscall_of_svc(arch, lines, func_list):
is_find_svc = True
continue
if is_find_svc and 'mov' in line and (svc_reg in line or svc_reg1 in line):
if is_find_svc and 'mov' in line and (svc_reg in line or svc_reg1 in line):
nr, is_find_nr, is_find_svc = line_find_syscall_nr(line, nr_set, nr)
continue
if is_find_nr and line[-1] == ':':
if is_find_nr and line[-1] == ':':
addr = line[:line.find(' ')]
addr = remove_head_zero(addr)
func_name = line[line.find('<') + 1: line.rfind('>')]
@ -166,7 +166,7 @@ def get_direct_use_syscall_of_syscall(arch, lines, func_list):
is_find_syscall = False
continue
if is_find_syscall_nr and line[-1] == ':':
if is_find_syscall_nr and line[-1] == ':':
addr = line[:line.find(' ')]
addr = remove_head_zero(addr)
func_name = line[line.find('<') + 1: line.rfind('>')]

View File

@ -28,27 +28,6 @@ class MergePolicy:
self.arches = set()
self.seccomp_policy_param = dict()
def update_parse_item(self, line):
item = line[1:]
if item in gen_policy.supported_parse_item:
self.cur_parse_item = item
print('start deal with {}'.format(self.cur_parse_item))
def parse_line(self, line):
if not self.cur_parse_item :
return
line = line.replace(' ', '')
pos = line.rfind(';')
if pos < 0:
for arch in self.arches:
self.seccomp_policy_param.get(arch).value_function.get(self.cur_parse_item)(line)
else:
arches = line[pos + 1:].split(',')
if arches[0] == 'all':
arches = gen_policy.supported_architecture
for arch in arches:
self.seccomp_policy_param.get(arch).value_function.get(self.cur_parse_item)(line[:pos])
@staticmethod
def get_item_content(name_nr_table, item_str, itme_dict):
syscall_name_dict = {}
@ -89,6 +68,28 @@ class MergePolicy:
[func_name for func_name, _ in syscall_name_dict.get('riscv64')]))
return content
def update_parse_item(self, line):
item = line[1:]
if item in gen_policy.supported_parse_item:
self.cur_parse_item = item
print('start deal with {}'.format(self.cur_parse_item))
def parse_line(self, line):
if not self.cur_parse_item :
return
line = line.replace(' ', '')
pos = line.rfind(';')
if pos < 0:
for arch in self.arches:
self.seccomp_policy_param.get(arch).value_function.get(self.cur_parse_item)(line)
else:
arches = line[pos + 1:].split(',')
if arches[0] == 'all':
arches = gen_policy.supported_architecture
for arch in arches:
self.seccomp_policy_param.get(arch).value_function.get(self.cur_parse_item)(line[:pos])
def parse_open_file(self, fp):
for line in fp:
line = line.strip()
@ -150,7 +151,7 @@ def main():
parser.add_argument('--src-files', type=str, action='append',
help=('input libsyscall_to_nr files and policy filse\n'))
parser.add_argument('--filter-name', type=str,
parser.add_argument('--filter-name', type=str,
help='Name of seccomp bpf array generated by this script')
args = parser.parse_args()