1
use rustix::{
2
    mm::{MlockAllFlags, mlockall},
3
    process::{Gid, Uid, getgid, getuid},
4
    thread::{
5
        CapabilitySet, CapabilitySets, capabilities, remove_capability_from_bounding_set,
6
        set_capabilities,
7
    },
8
};
9

            
10
// libc wrappers: rustix offers no process-wide setresuid/setresgid/setgroups (its
// `rustix::thread` variants change only the calling thread's credentials)
11
fn setresuid(ruid: Uid, euid: Uid, suid: Uid) -> Result<(), rustix::io::Errno> {
12
    let ret = unsafe { libc::setresuid(ruid.as_raw(), euid.as_raw(), suid.as_raw()) };
13
    if ret == 0 {
14
        Ok(())
15
    } else {
16
        Err(rustix::io::Errno::from_raw_os_error(
17
            std::io::Error::last_os_error()
18
                .raw_os_error()
19
                .unwrap_or(libc::EINVAL),
20
        ))
21
    }
22
}
23

            
24
fn setresgid(rgid: Gid, egid: Gid, sgid: Gid) -> Result<(), rustix::io::Errno> {
25
    let ret = unsafe { libc::setresgid(rgid.as_raw(), egid.as_raw(), sgid.as_raw()) };
26
    if ret == 0 {
27
        Ok(())
28
    } else {
29
        Err(rustix::io::Errno::from_raw_os_error(
30
            std::io::Error::last_os_error()
31
                .raw_os_error()
32
                .unwrap_or(libc::EINVAL),
33
        ))
34
    }
35
}
36

            
37
fn setgroups(groups: &[Gid]) -> Result<(), rustix::io::Errno> {
38
    let gids: Vec<libc::gid_t> = groups.iter().map(|g| g.as_raw()).collect();
39
    let ret = unsafe { libc::setgroups(gids.len(), gids.as_ptr()) };
40
    if ret == 0 {
41
        Ok(())
42
    } else {
43
        Err(rustix::io::Errno::from_raw_os_error(
44
            std::io::Error::last_os_error()
45
                .raw_os_error()
46
                .unwrap_or(libc::EINVAL),
47
        ))
48
    }
49
}
50

            
51
fn set_bounding_set(caps: CapabilitySet) -> Result<(), rustix::io::Errno> {
52
    let caps_to_drop = CapabilitySet::all().difference(caps);
53
    for cap in caps_to_drop.iter() {
54
        let _ = remove_capability_from_bounding_set(cap);
55
    }
56
    Ok(())
57
}
58

            
59
/// Rough classification of how the process obtained its capabilities.
#[derive(Debug, PartialEq)]
enum CapabilityState {
    /// Setuid-root binary, or running as the root user.
    Full,
    /// Filesystem-based (file) capabilities.
    Partial,
    /// No permitted or effective capabilities at all.
    None,
}
65

            
66
pub fn drop_unnecessary_capabilities() -> Result<(), rustix::io::Errno> {
67
    // Abort if we can't read capabilities (libcap-ng CAPNG_FAIL behavior)
68
    let caps = capabilities(None).unwrap_or_else(|e| {
69
        tracing::error!("Error getting process capabilities: {:?}, aborting", e);
70
        std::process::exit(1);
71
    });
72

            
73
    let capability_state = {
74
        if caps.permitted.is_empty() && caps.effective.is_empty() {
75
            CapabilityState::None
76
        } else {
77
            let all_caps = caps.effective | caps.permitted | caps.inheritable;
78
            // 10+ capabilities = Full (matches libcap-ng heuristic)
79
            if all_caps.bits().count_ones() >= 10 {
80
                CapabilityState::Full
81
            } else {
82
                CapabilityState::Partial
83
            }
84
        }
85
    };
86

            
87
    match capability_state {
88
        CapabilityState::Full => {
89
            set_capabilities(
90
                None,
91
                CapabilitySets {
92
                    effective: CapabilitySet::IPC_LOCK,
93
                    permitted: CapabilitySet::IPC_LOCK,
94
                    inheritable: CapabilitySet::empty(),
95
                },
96
            )?;
97

            
98
            // Needed so permitted caps survive uid 0 → non-zero transition
99
            if unsafe { libc::prctl(libc::PR_SET_KEEPCAPS, 1, 0, 0, 0) } != 0 {
100
                tracing::warn!("Failed to set PR_SET_KEEPCAPS");
101
            }
102

            
103
            if let Err(err) = set_bounding_set(CapabilitySet::IPC_LOCK) {
104
                tracing::debug!("Could not set bounding set (may not be supported): {}", err);
105
            }
106

            
107
            let uid = getuid();
108
            let gid = getgid();
109

            
110
            setresgid(gid, gid, gid)?;
111
            setgroups(&[])?;
112
            setresuid(uid, uid, uid)?; // Clears effective caps despite keepcaps
113

            
114
            if unsafe { libc::prctl(libc::PR_SET_KEEPCAPS, 0, 0, 0, 0) } != 0 {
115
                tracing::warn!("Failed to clear PR_SET_KEEPCAPS");
116
            }
117

            
118
            // Re-raise from permitted → effective
119
            set_capabilities(
120
                None,
121
                CapabilitySets {
122
                    effective: CapabilitySet::IPC_LOCK,
123
                    permitted: CapabilitySet::IPC_LOCK,
124
                    inheritable: CapabilitySet::empty(),
125
                },
126
            )?;
127
        }
128
        CapabilityState::None => {
129
            tracing::warn!("No process capabilities, insecure memory might get used");
130
            return Ok(());
131
        }
132
        CapabilityState::Partial => {
133
            if !caps.effective.contains(CapabilitySet::IPC_LOCK) {
134
                tracing::warn!("Insufficient process capabilities, insecure memory might get used");
135
            }
136

            
137
            // Clear bounding set if we have CAP_SETPCAP (do this before dropping caps)
138
            if caps.effective.contains(CapabilitySet::SETPCAP)
139
                && let Err(err) = set_bounding_set(CapabilitySet::IPC_LOCK)
140
            {
141
                tracing::warn!("Failed to set bounding set: {}", err);
142
            }
143

            
144
            set_capabilities(
145
                None,
146
                CapabilitySets {
147
                    effective: CapabilitySet::IPC_LOCK,
148
                    permitted: CapabilitySet::IPC_LOCK,
149
                    inheritable: CapabilitySet::empty(),
150
                },
151
            )?;
152
        }
153
    }
154

            
155
    // After dropping capabilities, try to lock memory
156
    // This prevents secrets from being swapped to disk
157
    match mlockall(MlockAllFlags::CURRENT | MlockAllFlags::FUTURE) {
158
        Ok(_) => {
159
            tracing::info!("Successfully locked all memory pages");
160
        }
161
        Err(e) => {
162
            tracing::warn!(
163
                "Failed to lock memory pages (secrets may be swapped to disk): {}",
164
                e
165
            );
166
        }
167
    }
168

            
169
    Ok(())
170
}