diff --git a/apps/supervisor/src/services/computeSnapshotService.ts b/apps/supervisor/src/services/computeSnapshotService.ts index 7206f57fb7..041e2902c7 100644 --- a/apps/supervisor/src/services/computeSnapshotService.ts +++ b/apps/supervisor/src/services/computeSnapshotService.ts @@ -80,11 +80,13 @@ export class ComputeSnapshotService { /** Handle the callback from the gateway after a snapshot completes or fails. */ async handleCallback(body: SnapshotCallbackPayload) { + const snapshotId = body.status === "completed" ? body.snapshot_id : undefined; + this.logger.debug("Snapshot callback", { - snapshotId: body.snapshot_id, + snapshotId, instanceId: body.instance_id, status: body.status, - error: body.error, + error: body.status === "failed" ? body.error : undefined, metadata: body.metadata, durationMs: body.duration_ms, }); @@ -97,7 +99,7 @@ export class ComputeSnapshotService { return { ok: false as const, status: 400 }; } - this.#emitSnapshotSpan(runId, body.duration_ms, body.snapshot_id); + this.#emitSnapshotSpan(runId, body.duration_ms, snapshotId); if (body.status === "completed") { const result = await this.workerClient.submitSuspendCompletion({ diff --git a/apps/webapp/app/presenters/v3/RegionsPresenter.server.ts b/apps/webapp/app/presenters/v3/RegionsPresenter.server.ts index f72b8d2fc5..55bd30e33b 100644 --- a/apps/webapp/app/presenters/v3/RegionsPresenter.server.ts +++ b/apps/webapp/app/presenters/v3/RegionsPresenter.server.ts @@ -2,6 +2,7 @@ import { type Project } from "~/models/project.server"; import { type User } from "~/models/user.server"; import { FEATURE_FLAG } from "~/v3/featureFlags"; import { makeFlag } from "~/v3/featureFlags.server"; +import { defaultVisibilityFilter, resolveComputeAccess } from "~/v3/regionAccess.server"; import { BasePresenter } from "./basePresenter.server"; import { getCurrentPlan } from "~/services/platform.v3.server"; @@ -32,6 +33,9 @@ export class RegionsPresenter extends BasePresenter { organizationId: true, defaultWorkerGroupId: true, allowedWorkerQueues: true, + organization: { + select: { featureFlags: true }, + }, }, where: { slug: projectSlug, @@ -58,6 +62,11 @@ export class RegionsPresenter extends BasePresenter { throw new Error("Default worker instance group not found"); } + const hasComputeAccess = await resolveComputeAccess( + this._replica, + project.organization.featureFlags + ); + const visibleRegions = await this._replica.workerInstanceGroup.findMany({ select: { id: true, @@ -75,9 +84,7 @@ export class RegionsPresenter extends BasePresenter { ? { masterQueue: { in: project.allowedWorkerQueues }, } - : { - hidden: false, - }, + : defaultVisibilityFilter(hasComputeAccess), orderBy: { name: "asc", }, diff --git a/apps/webapp/app/v3/regionAccess.server.ts b/apps/webapp/app/v3/regionAccess.server.ts new file mode 100644 index 0000000000..c3e338cb94 --- /dev/null +++ b/apps/webapp/app/v3/regionAccess.server.ts @@ -0,0 +1,50 @@ +import { type Prisma, type WorkloadType } from "@trigger.dev/database"; +import { type PrismaClientOrTransaction } from "~/db.server"; +import { FEATURE_FLAG } from "./featureFlags"; +import { makeFlag } from "./featureFlags.server"; + +/** + * Resolves whether an org has compute access based on feature flags. + */ +export async function resolveComputeAccess( + prisma: PrismaClientOrTransaction, + orgFeatureFlags: unknown +): Promise { + const flag = makeFlag(prisma); + return flag({ + key: FEATURE_FLAG.hasComputeAccess, + defaultValue: false, + overrides: (orgFeatureFlags as Record) ?? {}, + }); +} + +/** + * Builds a visibility filter for non-admin, non-allowlisted users. + * Without compute access, MICROVM regions are excluded entirely. + * With compute access, hidden flag works normally (existing behavior). + */ +export function defaultVisibilityFilter( + hasComputeAccess: boolean +): Prisma.WorkerInstanceGroupWhereInput { + if (hasComputeAccess) { + return { hidden: false }; + } + + return { hidden: false, workloadType: { not: "MICROVM" } }; +} + +/** + * Whether a region is accessible given compute access. + * MICROVM regions require compute access; all other types pass through. + */ +export function isComputeRegionAccessible( + region: { workloadType: WorkloadType }, + hasComputeAccess: boolean +): boolean { + if (region.workloadType !== "MICROVM") { + return true; + } + + // Allow access to any MICROVM region if the org has compute access + return hasComputeAccess; +} diff --git a/apps/webapp/app/v3/services/computeTemplateCreation.server.ts b/apps/webapp/app/v3/services/computeTemplateCreation.server.ts index 4daa2667f2..37235aa161 100644 --- a/apps/webapp/app/v3/services/computeTemplateCreation.server.ts +++ b/apps/webapp/app/v3/services/computeTemplateCreation.server.ts @@ -3,11 +3,10 @@ import { machinePresetFromName } from "~/v3/machinePresets.server"; import { env } from "~/env.server"; import { logger } from "~/services/logger.server"; import type { PrismaClientOrTransaction } from "~/db.server"; -import { FEATURE_FLAG } from "~/v3/featureFlags"; -import { makeFlag } from "~/v3/featureFlags.server"; import type { AuthenticatedEnvironment } from "~/services/apiAuth.server"; import { ServiceValidationError } from "./baseService.server"; import { FailDeploymentService } from "./failDeployment.server"; +import { resolveComputeAccess } from "../regionAccess.server"; type TemplateCreationMode = "required" | "shadow" | "skip"; @@ -101,9 +100,7 @@ export class ComputeTemplateCreationService { }, }); - throw new ServiceValidationError( - `Compute template creation failed: ${result.error}` - ); + throw new ServiceValidationError(`Compute template creation failed: ${result.error}`); } logger.info("Compute template created", { @@ -132,16 +129,15 @@ export class ComputeTemplateCreationService { }, }); - if (project?.defaultWorkerGroup?.workloadType === "MICROVM") { + if (!project) { + return "skip"; + } + + if (project.defaultWorkerGroup?.workloadType === "MICROVM") { return "required"; } - const flag = makeFlag(prisma); - const hasComputeAccess = await flag({ - key: FEATURE_FLAG.hasComputeAccess, - defaultValue: false, - overrides: (project?.organization?.featureFlags as Record) ?? {}, - }); + const hasComputeAccess = await resolveComputeAccess(prisma, project.organization.featureFlags); if (hasComputeAccess) { return "shadow"; diff --git a/apps/webapp/app/v3/services/setDefaultRegion.server.ts b/apps/webapp/app/v3/services/setDefaultRegion.server.ts index cada819452..e484b9c434 100644 --- a/apps/webapp/app/v3/services/setDefaultRegion.server.ts +++ b/apps/webapp/app/v3/services/setDefaultRegion.server.ts @@ -1,3 +1,4 @@ +import { isComputeRegionAccessible, resolveComputeAccess } from "~/v3/regionAccess.server"; import { BaseService, ServiceValidationError } from "./baseService.server"; export class SetDefaultRegionService extends BaseService { @@ -24,6 +25,9 @@ export class SetDefaultRegionService extends BaseService { where: { id: projectId, }, + include: { + organization: { select: { featureFlags: true } }, + }, }); if (!project) { @@ -36,8 +40,21 @@ export class SetDefaultRegionService extends BaseService { if (!project.allowedWorkerQueues.includes(workerGroup.masterQueue)) { throw new ServiceValidationError("You're not allowed to set this region as default"); } - } else if (workerGroup.hidden) { - throw new ServiceValidationError("This region is not available to you"); + } else { + if (workerGroup.hidden) { + throw new ServiceValidationError("This region is not available to you"); + } + + if (workerGroup.workloadType === "MICROVM") { + const hasComputeAccess = await resolveComputeAccess( + this._prisma, + project.organization.featureFlags + ); + + if (!isComputeRegionAccessible(workerGroup, hasComputeAccess)) { + throw new ServiceValidationError("This region requires compute access"); + } + } } } diff --git a/apps/webapp/app/v3/services/worker/workerGroupService.server.ts b/apps/webapp/app/v3/services/worker/workerGroupService.server.ts index 6a900a16fd..6a2c19cf24 100644 --- a/apps/webapp/app/v3/services/worker/workerGroupService.server.ts +++ b/apps/webapp/app/v3/services/worker/workerGroupService.server.ts @@ -4,6 +4,7 @@ import { WorkerGroupTokenService } from "./workerGroupTokenService.server"; import { logger } from "~/services/logger.server"; import { FEATURE_FLAG } from "~/v3/featureFlags"; import { makeFlag, makeSetFlag } from "~/v3/featureFlags.server"; +import { isComputeRegionAccessible, resolveComputeAccess } from "~/v3/regionAccess.server"; export class WorkerGroupService extends WithRunEngine { private readonly defaultNamePrefix = "worker_group"; @@ -207,6 +208,7 @@ export class WorkerGroupService extends WithRunEngine { }, include: { defaultWorkerGroup: true, + organization: { select: { featureFlags: true } }, }, }); @@ -243,6 +245,17 @@ export class WorkerGroupService extends WithRunEngine { throw new Error(`The region you specified isn't available to you ("${regionOverride}").`); } + if (workerGroup.workloadType === "MICROVM") { + const hasComputeAccess = await resolveComputeAccess( + this._prisma, + project.organization.featureFlags + ); + + if (!isComputeRegionAccessible(workerGroup, hasComputeAccess)) { + throw new Error(`The region you specified isn't available to you ("${regionOverride}").`); + } + } + return workerGroup; } diff --git a/internal-packages/compute/src/types.ts b/internal-packages/compute/src/types.ts index 296e38b59c..a2aa4c9760 100644 --- a/internal-packages/compute/src/types.ts +++ b/internal-packages/compute/src/types.ts @@ -62,12 +62,20 @@ export const SnapshotRestoreRequestSchema = z.object({ }); export type SnapshotRestoreRequest = z.infer; -export const SnapshotCallbackPayloadSchema = z.object({ - snapshot_id: z.string(), - instance_id: z.string(), - status: z.enum(["completed", "failed"]), - error: z.string().optional(), - metadata: z.record(z.string()).optional(), - duration_ms: z.number().optional(), -}); +export const SnapshotCallbackPayloadSchema = z.discriminatedUnion("status", [ + z.object({ + status: z.literal("completed"), + snapshot_id: z.string(), + instance_id: z.string(), + metadata: z.record(z.string()).optional(), + duration_ms: z.number().optional(), + }), + z.object({ + status: z.literal("failed"), + instance_id: z.string(), + error: z.string().optional(), + metadata: z.record(z.string()).optional(), + duration_ms: z.number().optional(), + }), +]); export type SnapshotCallbackPayload = z.infer;