Merge pull request #292 from AikidoSec/retry-malware-db-download

Retry downloading the malware database 3 times
This commit is contained in:
bitterpanda 2026-01-14 15:58:39 +01:00 committed by GitHub
commit 5898fc851a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 210 additions and 27 deletions

View file

@ -1,5 +1,10 @@
import fetch from "make-fetch-happen";
import { getEcoSystem, ECOSYSTEM_JS, ECOSYSTEM_PY } from "../config/settings.js";
import {
getEcoSystem,
ECOSYSTEM_JS,
ECOSYSTEM_PY,
} from "../config/settings.js";
import { ui } from "../environment/userInteraction.js";
const malwareDatabaseUrls = {
[ECOSYSTEM_JS]: "https://malware-list.aikido.dev/malware_predictions.json",
@ -17,11 +22,19 @@ const malwareDatabaseUrls = {
* @returns {Promise<{malwareDatabase: MalwarePackage[], version: string | undefined}>}
*/
export async function fetchMalwareDatabase() {
const numberOfAttempts = 4;
return retry(async () => {
const ecosystem = getEcoSystem();
const malwareDatabaseUrl = malwareDatabaseUrls[/** @type {keyof typeof malwareDatabaseUrls} */ (ecosystem)];
const malwareDatabaseUrl =
malwareDatabaseUrls[
/** @type {keyof typeof malwareDatabaseUrls} */ (ecosystem)
];
const response = await fetch(malwareDatabaseUrl);
if (!response.ok) {
throw new Error(`Error fetching ${ecosystem} malware database: ${response.statusText}`);
throw new Error(
`Error fetching ${ecosystem} malware database: ${response.statusText}`
);
}
try {
@ -33,14 +46,21 @@ export async function fetchMalwareDatabase() {
} catch (/** @type {any} */ error) {
throw new Error(`Error parsing malware database: ${error.message}`);
}
}, numberOfAttempts);
}
/**
* @returns {Promise<string | undefined>}
*/
export async function fetchMalwareDatabaseVersion() {
const numberOfAttempts = 4;
return retry(async () => {
const ecosystem = getEcoSystem();
const malwareDatabaseUrl = malwareDatabaseUrls[/** @type {keyof typeof malwareDatabaseUrls} */ (ecosystem)];
const malwareDatabaseUrl =
malwareDatabaseUrls[
/** @type {keyof typeof malwareDatabaseUrls} */ (ecosystem)
];
const response = await fetch(malwareDatabaseUrl, {
method: "HEAD",
});
@ -51,4 +71,42 @@ export async function fetchMalwareDatabaseVersion() {
);
}
return response.headers.get("etag") || undefined;
}, numberOfAttempts);
}
/**
* Retries an asynchronous function multiple times until it succeeds or exhausts all attempts.
*
* @template T
* @param {() => Promise<T>} func - The asynchronous function to retry
* @param {number} attempts - The number of attempts
* @returns {Promise<T>} The return value of the function if successful
* @throws {Error} The last error encountered if all retry attempts fail
*/
async function retry(func, attempts) {
let lastError;
for (let i = 0; i < attempts; i++) {
try {
return await func();
} catch (error) {
ui.writeVerbose(
"An error occurred while trying to download the Aikido Malware database",
error
);
lastError = error;
}
if (i < attempts - 1) {
// When this is not the last try, back-off exponentially:
// 1st attempt - 500ms delay
// 2nd attempt - 1000ms delay
// 3rd attempt - 2000ms delay
// 4th attempt - 4000ms delay
// ...
await new Promise((resolve) => setTimeout(resolve, Math.pow(2, i) * 500));
}
}
throw lastError;
}

View file

@ -0,0 +1,125 @@
import { describe, it, mock, beforeEach } from "node:test";
import assert from "node:assert";
describe("aikido API", async () => {
const mockFetch = mock.fn();
mock.module("make-fetch-happen", {
defaultExport: mockFetch,
});
mock.module("../config/settings.js", {
namedExports: {
getEcoSystem: () => "js",
ECOSYSTEM_JS: "js",
ECOSYSTEM_PY: "py",
},
});
const { fetchMalwareDatabase, fetchMalwareDatabaseVersion } =
await import("./aikido.js");
beforeEach(() => {
mockFetch.mock.resetCalls();
});
describe("fetchMalwareDatabase", () => {
it("should succeed immediately when fetch succeeds on first try", async () => {
const malwareData = [
{ package_name: "malicious-pkg", version: "1.0.0", reason: "test" },
];
mockFetch.mock.mockImplementationOnce(() => ({
ok: true,
json: async () => malwareData,
headers: { get: () => '"etag-123"' },
}));
const result = await fetchMalwareDatabase();
assert.strictEqual(mockFetch.mock.calls.length, 1);
assert.deepStrictEqual(result.malwareDatabase, malwareData);
assert.strictEqual(result.version, '"etag-123"');
});
it("should throw error after exhausting all retries", async () => {
mockFetch.mock.mockImplementation(() => {
throw new Error("Network error");
});
await assert.rejects(() => fetchMalwareDatabase(), {
message: "Network error",
});
assert.strictEqual(mockFetch.mock.calls.length, 4);
});
it("should succeed after failing 3 times and succeeding on 4th attempt", async () => {
const malwareData = [
{ package_name: "bad-pkg", version: "2.0.0", reason: "malware" },
];
let callCount = 0;
mockFetch.mock.mockImplementation(() => {
callCount++;
if (callCount < 4) {
throw new Error("Network error");
}
return {
ok: true,
json: async () => malwareData,
headers: { get: () => '"etag-456"' },
};
});
const result = await fetchMalwareDatabase();
assert.strictEqual(mockFetch.mock.calls.length, 4);
assert.deepStrictEqual(result.malwareDatabase, malwareData);
assert.strictEqual(result.version, '"etag-456"');
});
});
describe("fetchMalwareDatabaseVersion", () => {
it("should succeed immediately when fetch succeeds on first try", async () => {
mockFetch.mock.mockImplementationOnce(() => ({
ok: true,
headers: { get: () => '"version-etag"' },
}));
const result = await fetchMalwareDatabaseVersion();
assert.strictEqual(mockFetch.mock.calls.length, 1);
assert.strictEqual(result, '"version-etag"');
});
it("should throw error after exhausting all retries", async () => {
mockFetch.mock.mockImplementation(() => {
throw new Error("Connection refused");
});
await assert.rejects(() => fetchMalwareDatabaseVersion(), {
message: "Connection refused",
});
assert.strictEqual(mockFetch.mock.calls.length, 4);
});
it("should succeed after failing 3 times and succeeding on 4th attempt", async () => {
let callCount = 0;
mockFetch.mock.mockImplementation(() => {
callCount++;
if (callCount < 4) {
throw new Error("Timeout");
}
return {
ok: true,
headers: { get: () => '"final-etag"' },
};
});
const result = await fetchMalwareDatabaseVersion();
assert.strictEqual(mockFetch.mock.calls.length, 4);
assert.strictEqual(result, '"final-etag"');
});
});
});