Skip to content

Commit 466e55a

Browse files
committed
feat: enhance struct_ops handling with terminal return values and tail-call logic
1 parent 193f819 commit 466e55a

2 files changed

Lines changed: 61 additions & 106 deletions

File tree

src/ir_generator.ml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,8 @@ let normalize_struct_ops_instance_name name =
388388
Buffer.contents buffer
389389

390390
let generate_default_struct_ops_name instance_name =
391+
(* BPF_OBJ_NAME_LEN is 16 bytes including the NUL terminator, so the
392+
usable name length is 15 characters. *)
391393
let max_len = 15 in
392394
let normalized = normalize_struct_ops_instance_name instance_name in
393395
if String.length normalized <= max_len then normalized
@@ -406,6 +408,22 @@ let generate_default_struct_ops_name instance_name =
406408
if String.length abbreviated <= max_len then abbreviated
407409
else String.sub abbreviated 0 max_len
408410

411+
(* Decide whether a tail-call return (IRReturnCall) should be emitted for a
412+
call to [name] in the current context.
413+
414+
Two intentional behaviour changes vs. the previous per-site inline logic:
415+
416+
1. [is_function_pointer] now checks for [IRFunctionPointer] specifically
417+
instead of [Hashtbl.mem ctx.variable_types name]. The old check was
418+
too broad: any local variable (int, pointer, …) with the same name
419+
would be treated as a function pointer and block tail-call lowering.
420+
421+
2. A tail call is only emitted when [current_program_type] is set to a
422+
known attributed type (e.g. XDP, TC, kprobe). Helper functions that
423+
are lowered outside of an attributed program context therefore never
424+
produce tail calls, which is correct because they have no prog_array
425+
to dispatch into. struct_ops methods are explicitly excluded via the
426+
[StructOps] branch. *)
409427
let should_lower_as_implicit_tail_call ctx name =
410428
let is_function_pointer =
411429
Hashtbl.mem ctx.function_parameters name ||

src/userspace_codegen.ml

Lines changed: 43 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -2344,8 +2344,10 @@ let collect_function_usage_from_ir_function ?(global_variables = []) ir_func =
23442344

23452345
type struct_ops_main_registration = {
23462346
result_value: ir_value;
2347-
result_name: string;
2347+
result_name: string; (** variable holding the attach() return value *)
23482348
instance_name: string;
2349+
terminal_return_name: string; (** raw IR name of the variable main() returns *)
2350+
terminal_return_value: ir_value; (** ir_value of the final return - used for C name generation *)
23492351
}
23502352

23512353
let ir_value_variable_name ir_value =
@@ -2362,14 +2364,19 @@ let struct_ops_instance_name ir_value =
23622364
| IRStruct (name, _) -> Some name
23632365
| _ -> None)
23642366

2367+
(** Find the single struct_ops registration in [ir_func] and the variable
2368+
that is ultimately returned from [main]. Returns [None] if the pattern
2369+
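cannot be identified unambiguously from the IR, i.e. unless exactly one registration is found and the last block ends in a value-carrying [IRReturn]. *)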
23652370
let find_struct_ops_main_registration ir_func =
23662371
let registrations = List.fold_left (fun acc block ->
23672372
List.fold_left (fun inner_acc instr ->
23682373
match instr.instr_desc with
23692374
| IRStructOpsRegister (result_val, struct_ops_val) ->
23702375
(match ir_value_variable_name result_val, struct_ops_instance_name struct_ops_val with
23712376
| Some result_name, Some instance_name ->
2372-
{ result_value = result_val; result_name; instance_name } :: inner_acc
2377+
{ result_value = result_val; result_name; instance_name;
2378+
terminal_return_name = result_name;
2379+
terminal_return_value = result_val } :: inner_acc
23732380
| _ -> inner_acc)
23742381
| _ -> inner_acc
23752382
) acc block.instructions
@@ -2378,87 +2385,14 @@ let find_struct_ops_main_registration ir_func =
23782385
| last_block :: _, [registration] ->
23792386
(match List.rev last_block.instructions with
23802387
| { instr_desc = IRReturn (Some return_val); _ } :: _ ->
2381-
if ir_value_variable_name return_val = Some registration.result_name then
2382-
Some registration
2383-
else
2384-
None
2388+
let terminal_return_name =
2389+
Option.value ~default:registration.result_name
2390+
(ir_value_variable_name return_val)
2391+
in
2392+
Some { registration with terminal_return_name; terminal_return_value = return_val }
23852393
| _ -> None)
23862394
| _ -> None
23872395

2388-
let is_c_identifier value =
2389-
let is_ident_start = function
2390-
| 'a' .. 'z' | 'A' .. 'Z' | '_' -> true
2391-
| _ -> false
2392-
in
2393-
let is_ident_char = function
2394-
| 'a' .. 'z' | 'A' .. 'Z' | '0' .. '9' | '_' -> true
2395-
| _ -> false
2396-
in
2397-
String.length value > 0
2398-
&& is_ident_start value.[0]
2399-
&&
2400-
let rec check index =
2401-
if index >= String.length value then true
2402-
else if is_ident_char value.[index] then check (index + 1)
2403-
else false
2404-
in
2405-
check 1
2406-
2407-
let extract_terminal_return_identifier body_c =
2408-
let lines = Array.of_list (String.split_on_char '\n' body_c) in
2409-
let rec drop_leading_blank_lines = function
2410-
| line :: rest when String.trim line = "" -> drop_leading_blank_lines rest
2411-
| remaining -> remaining
2412-
in
2413-
let rec find_last_nonempty index =
2414-
if index < 0 then None
2415-
else if String.trim lines.(index) = "" then find_last_nonempty (index - 1)
2416-
else Some index
2417-
in
2418-
match find_last_nonempty (Array.length lines - 1) with
2419-
| None -> None
2420-
| Some index ->
2421-
let trimmed_line = String.trim lines.(index) in
2422-
let prefix = "return " in
2423-
if String.length trimmed_line > String.length prefix
2424-
&& String.sub trimmed_line 0 (String.length prefix) = prefix
2425-
&& trimmed_line.[String.length trimmed_line - 1] = ';' then
2426-
let expr = String.sub trimmed_line (String.length prefix) (String.length trimmed_line - String.length prefix - 1) |> String.trim in
2427-
if is_c_identifier expr then
2428-
let kept_lines =
2429-
Array.to_list (Array.sub lines 0 index)
2430-
|> List.rev
2431-
|> drop_leading_blank_lines
2432-
|> List.rev
2433-
in
2434-
Some (String.concat "\n" kept_lines, expr)
2435-
else
2436-
None
2437-
else
2438-
None
2439-
2440-
let extract_attach_result_identifier body_c instance_name =
2441-
let attach_call = sprintf "attach_struct_ops_%s();" instance_name in
2442-
let extract_identifier_from_lhs line =
2443-
match String.index_opt line '=' with
2444-
| None -> None
2445-
| Some eq_index ->
2446-
let lhs = String.sub line 0 eq_index |> String.trim in
2447-
if is_c_identifier lhs then Some lhs else None
2448-
in
2449-
String.split_on_char '\n' body_c
2450-
|> List.find_map (fun line ->
2451-
if String.contains line '=' && String.contains line 'a' && String.trim line <> "" then
2452-
let trimmed_line = String.trim line in
2453-
if String.length trimmed_line >= String.length attach_call
2454-
&& String.contains trimmed_line '='
2455-
&& String.ends_with ~suffix:attach_call trimmed_line then
2456-
extract_identifier_from_lhs trimmed_line
2457-
else
2458-
None
2459-
else
2460-
None)
2461-
24622396
(** Generate config initialization from declaration defaults *)
24632397
let generate_config_initialization (config_decl : Ast.config_declaration) =
24642398
let config_name = config_decl.config_name in
@@ -2796,7 +2730,7 @@ let generate_c_function_from_ir ?(global_variables = []) ?(base_name = "") ?(con
27962730
if index = List.length ir_func.basic_blocks - 1 then
27972731
match struct_ops_main_registration, List.rev block.instructions with
27982732
| Some registration, { instr_desc = IRReturn (Some return_val); _ } :: rest_rev
2799-
when ir_value_variable_name return_val = Some registration.result_name ->
2733+
when ir_value_variable_name return_val = Some registration.terminal_return_name ->
28002734
List.rev rest_rev
28012735
| _ -> block.instructions
28022736
else
@@ -2811,23 +2745,10 @@ let generate_c_function_from_ir ?(global_variables = []) ?(base_name = "") ?(con
28112745
let body_c =
28122746
let lifecycle_info = match struct_ops_main_registration with
28132747
| Some registration ->
2814-
let result_name = generate_c_value_from_ir ctx registration.result_value in
2815-
Some (body_c, result_name, registration.instance_name, result_name)
2816-
| None ->
2817-
(match ir_multi_prog with
2818-
| Some multi_prog ->
2819-
(match Ir.get_struct_ops_instances multi_prog with
2820-
| [instance] ->
2821-
(match extract_terminal_return_identifier body_c with
2822-
| Some (body_prefix, result_name) ->
2823-
let attach_result_name = match extract_attach_result_identifier body_prefix instance.ir_instance_name with
2824-
| Some name -> name
2825-
| None -> result_name
2826-
in
2827-
Some (body_prefix, result_name, instance.ir_instance_name, attach_result_name)
2828-
| None -> None)
2829-
| _ -> None)
2830-
| None -> None)
2748+
let attach_status_str = generate_c_value_from_ir ctx registration.result_value in
2749+
let result_str = generate_c_value_from_ir ctx registration.terminal_return_value in
2750+
Some (body_c, result_str, registration.instance_name, attach_status_str)
2751+
| None -> None
28312752
in
28322753
match lifecycle_info with
28332754
| Some (body_prefix, result_str, instance_name, attach_status_str) ->
@@ -2960,7 +2881,10 @@ let generate_struct_ops_runtime_helpers base_name ir_multi_program =
29602881
}|} instance_name instance_name instance_name)
29612882
|> String.concat "\n\n"
29622883
in
2963-
sprintf {|%s
2884+
sprintf {|#include <linux/capability.h>
2885+
#include <sys/syscall.h>
2886+
2887+
%s
29642888

29652889
static int bump_memlock_rlimit(void) {
29662890
struct rlimit rlim = {
@@ -2982,15 +2906,28 @@ static int bump_memlock_rlimit(void) {
29822906
return -1;
29832907
}
29842908

2985-
static int ensure_struct_ops_privileges(void) {
2986-
if (geteuid() == 0) {
2909+
/* Check whether the current process has the given effective capability bit.
2910+
Uses the capget(2) syscall directly to avoid a dependency on libcap. */
2911+
static int has_effective_cap(int cap) {
2912+
struct __user_cap_header_struct hdr = {
2913+
.version = _LINUX_CAPABILITY_VERSION_3,
2914+
.pid = 0,
2915+
};
2916+
struct __user_cap_data_struct data[2] = {};
2917+
if (syscall(__NR_capget, &hdr, data) != 0)
29872918
return 0;
2988-
}
2919+
return !!(data[cap >> 5].effective & (1U << (cap & 31)));
2920+
}
29892921

2990-
fprintf(stderr, "Warning: struct_ops loading typically requires root privileges or CAP_BPF/CAP_SYS_ADMIN.\n");
2991-
fprintf(stderr, "Continuing anyway; loading may still succeed if this process has the required capabilities.\n");
2992-
fprintf(stderr, "If it fails with a permission error, try: sudo ./%s\n");
2993-
return 0;
2922+
static int ensure_struct_ops_privileges(void) {
2923+
/* struct_ops loading requires either root or CAP_BPF (39) / CAP_SYS_ADMIN (21). */
2924+
if (geteuid() == 0 ||
2925+
has_effective_cap(39) ||
2926+
has_effective_cap(21))
2927+
return 0;
2928+
fprintf(stderr, "Error: struct_ops loading requires root or CAP_BPF/CAP_SYS_ADMIN.\n");
2929+
fprintf(stderr, "Try running as root: sudo ./%s\n");
2930+
return -1;
29942931
}
29952932

29962933
static void cleanup_%s(void) {

0 commit comments

Comments
 (0)