From 8bc394dd492286ca2f657901be234a2ec7e834a7 Mon Sep 17 00:00:00 2001 From: hakhandelwal11 Date: Tue, 19 May 2026 16:47:32 +0530 Subject: [PATCH 1/3] feat: add region flag --- pkg/cmd/gpucreate/gpucreate.go | 109 ++++++++++++++++- pkg/cmd/gpusearch/gpusearch.go | 181 +++++++++++++++++----------- pkg/cmd/gpusearch/gpusearch_test.go | 34 +++--- 3 files changed, 233 insertions(+), 91 deletions(-) diff --git a/pkg/cmd/gpucreate/gpucreate.go b/pkg/cmd/gpucreate/gpucreate.go index 7b379dbd..48a3a564 100644 --- a/pkg/cmd/gpucreate/gpucreate.go +++ b/pkg/cmd/gpucreate/gpucreate.go @@ -129,6 +129,7 @@ type CreateResult struct { // searchFilterFlags holds the search filter flag values for create type searchFilterFlags struct { gpuName string + region string provider string minVRAM float64 minTotalVRAM float64 @@ -144,7 +145,7 @@ type searchFilterFlags struct { // hasUserFilters returns true if the user specified any search filter flags func (f *searchFilterFlags) hasUserFilters() bool { - return f.gpuName != "" || f.provider != "" || f.minVRAM > 0 || f.minTotalVRAM > 0 || + return f.gpuName != "" || f.region != "" || f.provider != "" || f.minVRAM > 0 || f.minTotalVRAM > 0 || f.minCapability > 0 || f.minDisk > 0 || f.maxBootTime > 0 || f.stoppable || f.rebootable || f.flexPorts } @@ -233,6 +234,13 @@ func NewCmdGPUCreate(t *terminal.Terminal, gpuCreateStore GPUCreateStore) *cobra ComposeFile: composeFile, LaunchableID: launchableID, LaunchableInfo: launchableInfo, + Region: filters.region, + } + + if filters.region != "" { + if err := validateRegionExists(filters.region, gpuCreateStore); err != nil { + return err + } } opts.InstanceTypes, err = resolveInstanceTypes(cmd, gpuCreateStore, opts, types, &filters) @@ -240,6 +248,12 @@ func NewCmdGPUCreate(t *terminal.Terminal, gpuCreateStore GPUCreateStore) *cobra return err } + if filters.region != "" { + if err := validateRegion(filters.region, opts.InstanceTypes, gpuCreateStore); err != nil { + return err + } + } + if dryRun { return runDryRun(t, gpuCreateStore, opts.InstanceTypes, &filters) } @@ -281,6 +295,7 @@ func registerCreateFlags(cmd *cobra.Command, name, instanceTypes *string, count, cmd.Flags().StringVar(containerImage, "container-image", "", "Container image URL (required for container mode)") cmd.Flags().StringVar(composeFile, "compose-file", "", "Docker compose file path or URL (required for compose mode)") cmd.Flags().StringVarP(launchable, "launchable", "l", "", "Launchable ID or URL to deploy (e.g., env-XXX or console URL)") + cmd.Flags().StringVarP(&filters.region, "region", "r", "", "Region/location to deploy the instance (e.g., us-east-1, us-central1)") cmd.Flags().StringVarP(&filters.gpuName, "gpu-name", "g", "", "Filter by GPU name (e.g., A100, H100)") cmd.Flags().StringVar(&filters.provider, "provider", "", "Filter by provider/cloud (e.g., aws, gcp)") @@ -318,6 +333,7 @@ type GPUCreateOptions struct { ComposeFile string LaunchableID string LaunchableInfo *store.LaunchableResponse // populated when LaunchableID is set + Region string } // parseLaunchableID extracts a launchable ID from either a raw ID (env-XXX) or @@ -519,7 +535,7 @@ func searchInstances(s GPUCreateStore, filters *searchFilterFlags) ([]gpusearch. } instances := gpusearch.ProcessInstances(response.Items) - filtered := gpusearch.FilterInstances(instances, filters.gpuName, filters.provider, "", filters.minVRAM, + filtered := gpusearch.FilterInstances(instances, filters.gpuName, filters.region, filters.provider, "", filters.minVRAM, minTotalVRAM, minCapability, 0, minDisk, 0, maxBootTime, filters.stoppable, filters.rebootable, filters.flexPorts, true) gpusearch.SortInstances(filtered, sortBy, filters.descending) @@ -565,6 +581,91 @@ func runDryRun(t *terminal.Terminal, s GPUCreateStore, specs []InstanceSpec, fil return nil } +// validateRegionExists checks that the given region appears in at least one instance type in the catalog. +func validateRegionExists(region string, store GPUCreateStore) error { + response, err := store.GetInstanceTypes(false) + if err != nil { + return breverrors.WrapAndTrace(err) + } + if response == nil || len(response.Items) == 0 { + return nil + } + + regionLower := strings.ToLower(region) + for _, item := range response.Items { + if typeSupportsRegion(item, regionLower) { + return nil + } + } + + return breverrors.NewValidationError( + fmt.Sprintf("region %q is not offered by any instance type -- use 'brev search --json' to list valid regions", + region), + ) +} + +// validateRegion checks that every requested instance type is available in the given region. +func validateRegion(region string, types []InstanceSpec, store GPUCreateStore) error { + if len(types) == 0 { + return nil + } + + response, err := store.GetInstanceTypes(false) + if err != nil { + return breverrors.WrapAndTrace(err) + } + if response == nil || len(response.Items) == 0 { + return nil + } + + catalog := make(map[string]gpusearch.InstanceType, len(response.Items)) + for _, item := range response.Items { + catalog[item.Type] = item + } + + regionLower := strings.ToLower(region) + var unsupported []string + var unknown []string + + for _, spec := range types { + item, ok := catalog[spec.Type] + if !ok { + unknown = append(unknown, spec.Type) + continue + } + if !typeSupportsRegion(item, regionLower) { + unsupported = append(unsupported, spec.Type) + } + } + + if len(unknown) > 0 { + return breverrors.NewValidationError( + fmt.Sprintf("unknown instance type(s) %s -- use 'brev search' to list available types", strings.Join(unknown, ", ")), + ) + } + if len(unsupported) > 0 { + return breverrors.NewValidationError( + fmt.Sprintf("region %q is not available for instance type(s) %s -- use 'brev search --region %s' to find compatible types", + region, strings.Join(unsupported, ", "), region), + ) + } + return nil +} + +// typeSupportsRegion reports whether an instance type lists the given region (already lowercased) +// in either its primary Location or AvailableLocations, using substring matching. +func typeSupportsRegion(item gpusearch.InstanceType, regionLower string) bool { + if strings.Contains(strings.ToLower(item.Location), regionLower) { + return true + } + for _, loc := range item.AvailableLocations { + if strings.Contains(strings.ToLower(loc), regionLower) { + return true + } + } + return false +} + // orDefault returns val if it's non-zero, otherwise returns def func orDefault(val, def float64) float64 { if val > 0 { @@ -1047,6 +1148,10 @@ func (c *createContext) createWorkspace(name string, spec InstanceSpec) (*entity } } + if c.opts.Region != "" { + cwOptions.Location = c.opts.Region + } + // Apply launchable config or build mode if c.opts.LaunchableID != "" { applyLaunchableConfig(cwOptions, c.opts.LaunchableID, c.opts.LaunchableInfo) diff --git a/pkg/cmd/gpusearch/gpusearch.go b/pkg/cmd/gpusearch/gpusearch.go index 05943141..84b96ebd 100644 --- a/pkg/cmd/gpusearch/gpusearch.go +++ b/pkg/cmd/gpusearch/gpusearch.go @@ -132,6 +132,10 @@ Features column shows instance capabilities: # Filter by GPU name (case-insensitive, partial match) brev search gpu --gpu-name A100 + # Filter by region/location + brev search gpu --region us-east-1 + brev search gpu --region us-central + # Filter by minimum VRAM per GPU (in GB) brev search gpu --min-vram 24 @@ -150,6 +154,9 @@ Features column shows instance capabilities: # Filter by provider brev search cpu --provider aws + # Filter by region/location + brev search cpu --region us-east-1 + # Filter by minimum RAM brev search cpu --min-ram 64 @@ -163,6 +170,7 @@ Features column shows instance capabilities: // sharedFlags holds flags shared between gpu and cpu subcommands type sharedFlags struct { + region string provider string arch string minVCPU int @@ -179,6 +187,7 @@ type sharedFlags struct { // addSharedFlags adds common flags to a command func addSharedFlags(cmd *cobra.Command, f *sharedFlags) { + cmd.Flags().StringVarP(&f.region, "region", "r", "", "Filter by region/location (case-insensitive, partial match, e.g., us-east-1, us-central1)") cmd.Flags().StringVarP(&f.provider, "provider", "p", "", "Filter by provider/cloud (case-insensitive, partial match)") cmd.Flags().StringVar(&f.arch, "arch", "", "Filter by architecture (e.g., x86_64, arm64)") cmd.Flags().IntVar(&f.minVCPU, "min-vcpu", 0, "Minimum number of vCPUs") @@ -213,7 +222,7 @@ func NewCmdGPUSearch(t *terminal.Terminal, store GPUSearchStore) *cobra.Command Example: gpuExample, RunE: func(cmd *cobra.Command, args []string) error { // Default behavior: GPU search - return RunGPUSearch(t, store, gpuName, shared.provider, shared.arch, minVRAM, minTotalVRAM, minCapability, shared.minRAM, shared.minDisk, shared.minVCPU, shared.maxBootTime, shared.stoppable, shared.rebootable, shared.flexPorts, shared.sortBy, shared.descending, shared.jsonOutput, wide) + return RunGPUSearch(t, store, gpuName, shared.region, shared.provider, shared.arch, minVRAM, minTotalVRAM, minCapability, shared.minRAM, shared.minDisk, shared.minVCPU, shared.maxBootTime, shared.stoppable, shared.rebootable, shared.flexPorts, shared.sortBy, shared.descending, shared.jsonOutput, wide) }, } @@ -247,7 +256,7 @@ func newCmdGPUSubcommand(t *terminal.Terminal, store GPUSearchStore) *cobra.Comm Short: "Search GPU instance types", Example: gpuExample, RunE: func(cmd *cobra.Command, args []string) error { - return RunGPUSearch(t, store, gpuName, shared.provider, shared.arch, minVRAM, minTotalVRAM, minCapability, shared.minRAM, shared.minDisk, shared.minVCPU, shared.maxBootTime, shared.stoppable, shared.rebootable, shared.flexPorts, shared.sortBy, shared.descending, shared.jsonOutput, wide) + return RunGPUSearch(t, store, gpuName, shared.region, shared.provider, shared.arch, minVRAM, minTotalVRAM, minCapability, shared.minRAM, shared.minDisk, shared.minVCPU, shared.maxBootTime, shared.stoppable, shared.rebootable, shared.flexPorts, shared.sortBy, shared.descending, shared.jsonOutput, wide) }, } @@ -271,7 +280,7 @@ func newCmdCPUSubcommand(t *terminal.Terminal, store GPUSearchStore) *cobra.Comm Short: "Search CPU-only instance types", Example: cpuExample, RunE: func(cmd *cobra.Command, args []string) error { - return RunCPUSearch(t, store, shared.provider, shared.arch, shared.minRAM, shared.minDisk, shared.minVCPU, shared.maxBootTime, shared.stoppable, shared.rebootable, shared.flexPorts, shared.sortBy, shared.descending, shared.jsonOutput) + return RunCPUSearch(t, store, shared.region, shared.provider, shared.arch, shared.minRAM, shared.minDisk, shared.minVCPU, shared.maxBootTime, shared.stoppable, shared.rebootable, shared.flexPorts, shared.sortBy, shared.descending, shared.jsonOutput) }, } @@ -282,28 +291,30 @@ func newCmdCPUSubcommand(t *terminal.Terminal, store GPUSearchStore) *cobra.Comm // GPUInstanceInfo holds processed GPU instance information for display type GPUInstanceInfo struct { - Type string `json:"type"` - Cloud string `json:"cloud"` // Underlying cloud (e.g., hyperstack, aws, gcp) - Provider string `json:"provider"` // Provider/aggregator (e.g., shadeform, aws, gcp) - GPUName string `json:"gpu_name"` - GPUCount int `json:"gpu_count"` - VRAMPerGPU float64 `json:"vram_per_gpu_gb"` - TotalVRAM float64 `json:"total_vram_gb"` - Capability float64 `json:"capability"` - VCPUs int `json:"vcpus"` - Memory string `json:"memory"` - RAMInGB float64 `json:"ram_gb"` - Arch string `json:"arch"` - DiskMin float64 `json:"disk_min_gb"` - DiskMax float64 `json:"disk_max_gb"` - DiskPricePerMo float64 `json:"disk_price_per_gb_mo,omitempty"` // $/GB/month for flexible storage - BootTime int `json:"boot_time_seconds"` - Stoppable bool `json:"stoppable"` - Rebootable bool `json:"rebootable"` - FlexPorts bool `json:"flex_ports"` - TargetDisk float64 `json:"target_disk_gb,omitempty"` - PricePerHour float64 `json:"price_per_hour"` - Manufacturer string `json:"-"` // exclude from JSON output + Type string `json:"type"` + Cloud string `json:"cloud"` // Underlying cloud (e.g., hyperstack, aws, gcp) + Provider string `json:"provider"` // Provider/aggregator (e.g., shadeform, aws, gcp) + GPUName string `json:"gpu_name"` + GPUCount int `json:"gpu_count"` + VRAMPerGPU float64 `json:"vram_per_gpu_gb"` + TotalVRAM float64 `json:"total_vram_gb"` + Capability float64 `json:"capability"` + VCPUs int `json:"vcpus"` + Memory string `json:"memory"` + RAMInGB float64 `json:"ram_gb"` + Arch string `json:"arch"` + Region string `json:"region,omitempty"` + AvailableRegions []string `json:"available_regions,omitempty"` + DiskMin float64 `json:"disk_min_gb"` + DiskMax float64 `json:"disk_max_gb"` + DiskPricePerMo float64 `json:"disk_price_per_gb_mo,omitempty"` // $/GB/month for flexible storage + BootTime int `json:"boot_time_seconds"` + Stoppable bool `json:"stoppable"` + Rebootable bool `json:"rebootable"` + FlexPorts bool `json:"flex_ports"` + TargetDisk float64 `json:"target_disk_gb,omitempty"` + PricePerHour float64 `json:"price_per_hour"` + Manufacturer string `json:"-"` // exclude from JSON output } // IsStdoutPiped returns true if stdout is being piped (not a terminal) @@ -313,7 +324,7 @@ func IsStdoutPiped() bool { } // RunGPUSearch executes the GPU search with filters and sorting -func RunGPUSearch(t *terminal.Terminal, store GPUSearchStore, gpuName, provider, arch string, minVRAM, minTotalVRAM, minCapability, minRAM, minDisk float64, minVCPU, maxBootTime int, stoppable, rebootable, flexPorts bool, sortBy string, descending, jsonOutput, wide bool) error { +func RunGPUSearch(t *terminal.Terminal, store GPUSearchStore, gpuName, region, provider, arch string, minVRAM, minTotalVRAM, minCapability, minRAM, minDisk float64, minVCPU, maxBootTime int, stoppable, rebootable, flexPorts bool, sortBy string, descending, jsonOutput, wide bool) error { if err := validateSortOption(sortBy); err != nil { return err } @@ -332,7 +343,7 @@ func RunGPUSearch(t *terminal.Terminal, store GPUSearchStore, gpuName, provider, instances := ProcessInstances(response.Items) // Filter to GPU-only instances - filtered := FilterInstances(instances, gpuName, provider, arch, minVRAM, minTotalVRAM, minCapability, minRAM, minDisk, minVCPU, maxBootTime, stoppable, rebootable, flexPorts, false) + filtered := FilterInstances(instances, gpuName, region, provider, arch, minVRAM, minTotalVRAM, minCapability, minRAM, minDisk, minVCPU, maxBootTime, stoppable, rebootable, flexPorts, false) if len(filtered) == 0 { return displayEmptyResults(t, "No GPU instances match the specified filters", jsonOutput, piped) @@ -344,7 +355,7 @@ func RunGPUSearch(t *terminal.Terminal, store GPUSearchStore, gpuName, provider, } // RunCPUSearch executes the CPU search with filters and sorting -func RunCPUSearch(t *terminal.Terminal, store GPUSearchStore, provider, arch string, minRAM, minDisk float64, minVCPU, maxBootTime int, stoppable, rebootable, flexPorts bool, sortBy string, descending, jsonOutput bool) error { +func RunCPUSearch(t *terminal.Terminal, store GPUSearchStore, region, provider, arch string, minRAM, minDisk float64, minVCPU, maxBootTime int, stoppable, rebootable, flexPorts bool, sortBy string, descending, jsonOutput bool) error { if err := validateSortOption(sortBy); err != nil { return err } @@ -363,7 +374,7 @@ func RunCPUSearch(t *terminal.Terminal, store GPUSearchStore, provider, arch str instances := ProcessInstances(response.Items) // Filter to CPU-only instances - filtered := FilterCPUInstances(instances, provider, arch, minRAM, minDisk, minVCPU, maxBootTime, stoppable, rebootable, flexPorts) + filtered := FilterCPUInstances(instances, region, provider, arch, minRAM, minDisk, minVCPU, maxBootTime, stoppable, rebootable, flexPorts) if len(filtered) == 0 { return displayEmptyResults(t, "No CPU instances match the specified filters", jsonOutput, piped) @@ -741,24 +752,26 @@ func ProcessInstances(items []InstanceType) []GPUInstanceInfo { if len(item.SupportedGPUs) == 0 { // CPU-only instance instances = append(instances, GPUInstanceInfo{ - Type: item.Type, - Cloud: extractCloud(item.Type, item.Provider), - Provider: item.Provider, - GPUName: "-", - GPUCount: 0, - VCPUs: item.VCPU, - Memory: item.Memory, - RAMInGB: ramInGB, - Arch: arch, - DiskMin: diskMin, - DiskMax: diskMax, - DiskPricePerMo: diskPricePerMo, - BootTime: bootTime, - Stoppable: item.Stoppable, - Rebootable: item.Rebootable, - FlexPorts: item.CanModifyFirewallRules, - PricePerHour: price, - Manufacturer: "cpu", + Type: item.Type, + Cloud: extractCloud(item.Type, item.Provider), + Provider: item.Provider, + GPUName: "-", + GPUCount: 0, + VCPUs: item.VCPU, + Memory: item.Memory, + RAMInGB: ramInGB, + Arch: arch, + Region: item.Location, + AvailableRegions: item.AvailableLocations, + DiskMin: diskMin, + DiskMax: diskMax, + DiskPricePerMo: diskPricePerMo, + BootTime: bootTime, + Stoppable: item.Stoppable, + Rebootable: item.Rebootable, + FlexPorts: item.CanModifyFirewallRules, + PricePerHour: price, + Manufacturer: "cpu", }) continue } @@ -774,27 +787,29 @@ func ProcessInstances(items []InstanceType) []GPUInstanceInfo { capability := getGPUCapability(gpu.Name) instances = append(instances, GPUInstanceInfo{ - Type: item.Type, - Cloud: extractCloud(item.Type, item.Provider), - Provider: item.Provider, - GPUName: gpu.Name, - GPUCount: gpu.Count, - VRAMPerGPU: vramPerGPU, - TotalVRAM: totalVRAM, - Capability: capability, - VCPUs: item.VCPU, - Memory: item.Memory, - RAMInGB: ramInGB, - Arch: arch, - DiskMin: diskMin, - DiskMax: diskMax, - DiskPricePerMo: diskPricePerMo, - BootTime: bootTime, - Stoppable: item.Stoppable, - Rebootable: item.Rebootable, - FlexPorts: item.CanModifyFirewallRules, - PricePerHour: price, - Manufacturer: gpu.Manufacturer, + Type: item.Type, + Cloud: extractCloud(item.Type, item.Provider), + Provider: item.Provider, + GPUName: gpu.Name, + GPUCount: gpu.Count, + VRAMPerGPU: vramPerGPU, + TotalVRAM: totalVRAM, + Capability: capability, + VCPUs: item.VCPU, + Memory: item.Memory, + RAMInGB: ramInGB, + Arch: arch, + Region: item.Location, + AvailableRegions: item.AvailableLocations, + DiskMin: diskMin, + DiskMax: diskMax, + DiskPricePerMo: diskPricePerMo, + BootTime: bootTime, + Stoppable: item.Stoppable, + Rebootable: item.Rebootable, + FlexPorts: item.CanModifyFirewallRules, + PricePerHour: price, + Manufacturer: gpu.Manufacturer, }) } } @@ -805,6 +820,7 @@ func ProcessInstances(items []InstanceType) []GPUInstanceInfo { // FilterOptions holds all filter criteria for instances type FilterOptions struct { GPUName string + Region string Provider string Arch string MinVRAM float64 @@ -819,7 +835,7 @@ type FilterOptions struct { FlexPorts bool } -// matchesStringFilters checks GPU name and provider filters +// matchesStringFilters checks GPU name, region, provider, and architecture filters func (f *FilterOptions) matchesStringFilters(inst GPUInstanceInfo) bool { // Allow CPU-only instances through; filter out non-NVIDIA GPUs (AMD, Intel/Habana, etc.) if inst.Manufacturer != "cpu" && !strings.Contains(strings.ToUpper(inst.Manufacturer), "NVIDIA") { @@ -829,6 +845,10 @@ func (f *FilterOptions) matchesStringFilters(inst GPUInstanceInfo) bool { if f.GPUName != "" && !strings.Contains(strings.ToLower(inst.GPUName), strings.ToLower(f.GPUName)) { return false } + // Filter by region (case-insensitive partial match against primary location and available locations) + if f.Region != "" && !matchesRegion(f.Region, inst) { + return false + } // Filter by provider (case-insensitive partial match) if f.Provider != "" && !strings.Contains(strings.ToLower(inst.Provider), strings.ToLower(f.Provider)) { return false @@ -840,6 +860,22 @@ func (f *FilterOptions) matchesStringFilters(inst GPUInstanceInfo) bool { return true } +// matchesRegion checks if an instance is available in the given region +func matchesRegion(region string, inst GPUInstanceInfo) bool { + regionLower := strings.ToLower(region) + // Check primary location + if strings.Contains(strings.ToLower(inst.Region), regionLower) { + return true + } + // Check available locations + for _, loc := range inst.AvailableRegions { + if strings.Contains(strings.ToLower(loc), regionLower) { + return true + } + } + return false +} + // matchesNumericFilters checks VRAM, capability, disk, vCPU, and boot time filters func (f *FilterOptions) matchesNumericFilters(inst GPUInstanceInfo) bool { if f.MinVCPU > 0 && inst.VCPUs < f.MinVCPU { @@ -889,9 +925,10 @@ func (f *FilterOptions) matchesFilter(inst GPUInstanceInfo) bool { } // FilterInstances applies all filters to the instance list. When gpuOnly is true, CPU-only instances are excluded. -func FilterInstances(instances []GPUInstanceInfo, gpuName, provider, arch string, minVRAM, minTotalVRAM, minCapability, minRAM, minDisk float64, minVCPU, maxBootTime int, stoppable, rebootable, flexPorts, gpuOnly bool) []GPUInstanceInfo { +func FilterInstances(instances []GPUInstanceInfo, gpuName, region, provider, arch string, minVRAM, minTotalVRAM, minCapability, minRAM, minDisk float64, minVCPU, maxBootTime int, stoppable, rebootable, flexPorts, gpuOnly bool) []GPUInstanceInfo { opts := &FilterOptions{ GPUName: gpuName, + Region: region, Provider: provider, Arch: arch, MinVRAM: minVRAM, @@ -919,7 +956,7 @@ func FilterInstances(instances []GPUInstanceInfo, gpuName, provider, arch string } // FilterCPUInstances filters to CPU-only instances using shared filter logic -func FilterCPUInstances(instances []GPUInstanceInfo, provider, arch string, minRAM, minDisk float64, minVCPU, maxBootTime int, stoppable, rebootable, flexPorts bool) []GPUInstanceInfo { +func FilterCPUInstances(instances []GPUInstanceInfo, region, provider, arch string, minRAM, minDisk float64, minVCPU, maxBootTime int, stoppable, rebootable, flexPorts bool) []GPUInstanceInfo { // Filter out GPU instances first, then apply shared filters var cpuOnly []GPUInstanceInfo for _, inst := range instances { @@ -927,7 +964,7 @@ func FilterCPUInstances(instances []GPUInstanceInfo, provider, arch string, minR cpuOnly = append(cpuOnly, inst) } } - return FilterInstances(cpuOnly, "", provider, arch, 0, 0, 0, minRAM, minDisk, minVCPU, maxBootTime, stoppable, rebootable, flexPorts, false) + return FilterInstances(cpuOnly, "", region, provider, arch, 0, 0, 0, minRAM, minDisk, minVCPU, maxBootTime, stoppable, rebootable, flexPorts, false) } // SortInstances sorts the instance list by the specified column diff --git a/pkg/cmd/gpusearch/gpusearch_test.go b/pkg/cmd/gpusearch/gpusearch_test.go index cf633465..f1a9837c 100644 --- a/pkg/cmd/gpusearch/gpusearch_test.go +++ b/pkg/cmd/gpusearch/gpusearch_test.go @@ -168,19 +168,19 @@ func TestFilterInstancesByGPUName(t *testing.T) { instances := ProcessInstances(response.Items) // Filter by A10G - filtered := FilterInstances(instances, "A10G", "", "", 0, 0, 0, 0, 0, 0, 0, false, false, false, true) + filtered := FilterInstances(instances, "A10G", "", "", "", 0, 0, 0, 0, 0, 0, 0, false, false, false, true) assert.Len(t, filtered, 2, "Should have 2 A10G instances") // Filter by V100 - filtered = FilterInstances(instances, "V100", "", "", 0, 0, 0, 0, 0, 0, 0, false, false, false, true) + filtered = FilterInstances(instances, "V100", "", "", "", 0, 0, 0, 0, 0, 0, 0, false, false, false, true) assert.Len(t, filtered, 2, "Should have 2 V100 instances") // Filter by lowercase (case-insensitive) - filtered = FilterInstances(instances, "v100", "", "", 0, 0, 0, 0, 0, 0, 0, false, false, false, true) + filtered = FilterInstances(instances, "v100", "", "", "", 0, 0, 0, 0, 0, 0, 0, false, false, false, true) assert.Len(t, filtered, 2, "Should have 2 V100 instances (case-insensitive)") // Filter by partial match - filtered = FilterInstances(instances, "A1", "", "", 0, 0, 0, 0, 0, 0, 0, false, false, false, true) + filtered = FilterInstances(instances, "A1", "", "", "", 0, 0, 0, 0, 0, 0, 0, false, false, false, true) assert.Len(t, filtered, 3, "Should have 3 instances matching 'A1' (A10G and A100)") } @@ -189,11 +189,11 @@ func TestFilterInstancesByMinVRAM(t *testing.T) { instances := ProcessInstances(response.Items) // Filter by min VRAM 24GB - filtered := FilterInstances(instances, "", "", "", 24, 0, 0, 0, 0, 0, 0, false, false, false, true) + filtered := FilterInstances(instances, "", "", "", "", 24, 0, 0, 0, 0, 0, 0, false, false, false, true) assert.Len(t, filtered, 4, "Should have 4 instances with >= 24GB VRAM") // Filter by min VRAM 40GB - filtered = FilterInstances(instances, "", "", "", 40, 0, 0, 0, 0, 0, 0, false, false, false, true) + filtered = FilterInstances(instances, "", "", "", "", 40, 0, 0, 0, 0, 0, 0, false, false, false, true) assert.Len(t, filtered, 1, "Should have 1 instance with >= 40GB VRAM") assert.Equal(t, "A100", filtered[0].GPUName) } @@ -203,11 +203,11 @@ func TestFilterInstancesByMinTotalVRAM(t *testing.T) { instances := ProcessInstances(response.Items) // Filter by min total VRAM 60GB - filtered := FilterInstances(instances, "", "", "", 0, 60, 0, 0, 0, 0, 0, false, false, false, true) + filtered := FilterInstances(instances, "", "", "", "", 0, 60, 0, 0, 0, 0, 0, false, false, false, true) assert.Len(t, filtered, 2, "Should have 2 instances with >= 60GB total VRAM") // Filter by min total VRAM 300GB - filtered = FilterInstances(instances, "", "", "", 0, 300, 0, 0, 0, 0, 0, false, false, false, true) + filtered = FilterInstances(instances, "", "", "", "", 0, 300, 0, 0, 0, 0, 0, false, false, false, true) assert.Len(t, filtered, 1, "Should have 1 instance with >= 300GB total VRAM") assert.Equal(t, "p4d.24xlarge", filtered[0].Type) } @@ -217,11 +217,11 @@ func TestFilterInstancesByMinCapability(t *testing.T) { instances := ProcessInstances(response.Items) // Filter by capability >= 8.0 - filtered := FilterInstances(instances, "", "", "", 0, 0, 8.0, 0, 0, 0, 0, false, false, false, true) + filtered := FilterInstances(instances, "", "", "", "", 0, 0, 8.0, 0, 0, 0, 0, false, false, false, true) assert.Len(t, filtered, 4, "Should have 4 instances with capability >= 8.0") // Filter by capability >= 8.5 - filtered = FilterInstances(instances, "", "", "", 0, 0, 8.5, 0, 0, 0, 0, false, false, false, true) + filtered = FilterInstances(instances, "", "", "", "", 0, 0, 8.5, 0, 0, 0, 0, false, false, false, true) assert.Len(t, filtered, 3, "Should have 3 instances with capability >= 8.5") } @@ -230,11 +230,11 @@ func TestFilterInstancesCombined(t *testing.T) { instances := ProcessInstances(response.Items) // Filter by GPU name and min VRAM - filtered := FilterInstances(instances, "A10G", "", "", 24, 0, 0, 0, 0, 0, 0, false, false, false, true) + filtered := FilterInstances(instances, "A10G", "", "", "", 24, 0, 0, 0, 0, 0, 0, false, false, false, true) assert.Len(t, filtered, 2, "Should have 2 A10G instances with >= 24GB VRAM") // Filter by GPU name, min VRAM, and capability - filtered = FilterInstances(instances, "", "", "", 24, 0, 8.5, 0, 0, 0, 0, false, false, false, true) + filtered = FilterInstances(instances, "", "", "", "", 24, 0, 8.5, 0, 0, 0, 0, false, false, false, true) assert.Len(t, filtered, 3, "Should have 3 instances with >= 24GB VRAM and capability >= 8.5") } @@ -336,7 +336,7 @@ func TestEmptyInstanceTypes(t *testing.T) { assert.Len(t, instances, 0, "Should have 0 instances") - filtered := FilterInstances(instances, "A100", "", "", 0, 0, 0, 0, 0, 0, 0, false, false, false, true) + filtered := FilterInstances(instances, "A100", "", "", "", 0, 0, 0, 0, 0, 0, 0, false, false, false, true) assert.Len(t, filtered, 0, "Filtered should also be empty") } @@ -395,12 +395,12 @@ func TestNonGPUInstancesFilteredByDefault(t *testing.T) { instances := ProcessInstances(response.Items) // gpuOnly=true should filter out CPU instances - filtered := FilterInstances(instances, "", "", "", 0, 0, 0, 0, 0, 0, 0, false, false, false, true) + filtered := FilterInstances(instances, "", "", "", "", 0, 0, 0, 0, 0, 0, 0, false, false, false, true) assert.Len(t, filtered, 1, "gpuOnly should exclude CPU instances") assert.Equal(t, "g5.xlarge", filtered[0].Type) // gpuOnly=false should keep CPU instances - filtered = FilterInstances(instances, "", "", "", 0, 0, 0, 0, 0, 0, 0, false, false, false, false) + filtered = FilterInstances(instances, "", "", "", "", 0, 0, 0, 0, 0, 0, 0, false, false, false, false) assert.Len(t, filtered, 2, "Without gpuOnly, both CPU and GPU instances pass") } @@ -464,7 +464,7 @@ func TestFilterByMaxBootTimeExcludesUnknown(t *testing.T) { assert.Len(t, instances, 3, "Should have 3 instances before filtering") // Filter by max boot time of 10 minutes - should exclude unknown and slow-boot - filtered := FilterInstances(instances, "", "", "", 0, 0, 0, 0, 0, 0, 10, false, false, false, true) + filtered := FilterInstances(instances, "", "", "", "", 0, 0, 0, 0, 0, 0, 10, false, false, false, true) assert.Len(t, filtered, 1, "Should have 1 instance with boot time <= 10 minutes") assert.Equal(t, "fast-boot", filtered[0].Type, "Only fast-boot should match") @@ -475,7 +475,7 @@ func TestFilterByMaxBootTimeExcludesUnknown(t *testing.T) { } // Without filter, all instances should be included - noFilter := FilterInstances(instances, "", "", "", 0, 0, 0, 0, 0, 0, 0, false, false, false, true) + noFilter := FilterInstances(instances, "", "", "", "", 0, 0, 0, 0, 0, 0, 0, false, false, false, true) assert.Len(t, noFilter, 3, "Without filter, all 3 instances should be included") } From b9e1dd3a6e2a9f43efb76fc67b5903a6e1a508e7 Mon Sep 17 00:00:00 2001 From: hakhandelwal11 Date: Wed, 10 Jun 2026 19:02:43 +0530 Subject: [PATCH 2/3] fix: match regions only against AvailableLocations --- pkg/cmd/gpucreate/gpucreate.go | 7 ++----- pkg/cmd/gpusearch/gpusearch.go | 9 +-------- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/pkg/cmd/gpucreate/gpucreate.go b/pkg/cmd/gpucreate/gpucreate.go index 48a3a564..66523c2d 100644 --- a/pkg/cmd/gpucreate/gpucreate.go +++ b/pkg/cmd/gpucreate/gpucreate.go @@ -652,12 +652,9 @@ func validateRegion(region string, types []InstanceSpec, store GPUCreateStore) e return nil } -// typeSupportsRegion reports whether an instance type lists the given region (already lowercased) -// in either its primary Location or AvailableLocations, using substring matching. +// typeSupportsRegion reports whether an instance type lists the given region +// (already lowercased) in its AvailableLocations, using substring matching. func typeSupportsRegion(item gpusearch.InstanceType, regionLower string) bool { - if strings.Contains(strings.ToLower(item.Location), regionLower) { - return true - } for _, loc := range item.AvailableLocations { if strings.Contains(strings.ToLower(loc), regionLower) { return true diff --git a/pkg/cmd/gpusearch/gpusearch.go b/pkg/cmd/gpusearch/gpusearch.go index 84b96ebd..ba1b51a4 100644 --- a/pkg/cmd/gpusearch/gpusearch.go +++ b/pkg/cmd/gpusearch/gpusearch.go @@ -303,7 +303,6 @@ type GPUInstanceInfo struct { Memory string `json:"memory"` RAMInGB float64 `json:"ram_gb"` Arch string `json:"arch"` - Region string `json:"region,omitempty"` AvailableRegions []string `json:"available_regions,omitempty"` DiskMin float64 `json:"disk_min_gb"` DiskMax float64 `json:"disk_max_gb"` @@ -761,7 +760,6 @@ func ProcessInstances(items []InstanceType) []GPUInstanceInfo { Memory: item.Memory, RAMInGB: ramInGB, Arch: arch, - Region: item.Location, AvailableRegions: item.AvailableLocations, DiskMin: diskMin, DiskMax: diskMax, @@ -799,7 +797,6 @@ func ProcessInstances(items []InstanceType) []GPUInstanceInfo { Memory: item.Memory, RAMInGB: ramInGB, Arch: arch, - Region: item.Location, AvailableRegions: item.AvailableLocations, DiskMin: diskMin, DiskMax: diskMax, @@ -861,13 +858,9 @@ func (f *FilterOptions) matchesStringFilters(inst GPUInstanceInfo) bool { } // matchesRegion checks if an instance is available in the given region +// via substring match against any entry in AvailableRegions. func matchesRegion(region string, inst GPUInstanceInfo) bool { regionLower := strings.ToLower(region) - // Check primary location - if strings.Contains(strings.ToLower(inst.Region), regionLower) { - return true - } - // Check available locations for _, loc := range inst.AvailableRegions { if strings.Contains(strings.ToLower(loc), regionLower) { return true From 066fc512ef8e25edc7dc4a68bed39c13df3e565a Mon Sep 17 00:00:00 2001 From: hakhandelwal11 Date: Wed, 1 Jul 2026 18:46:21 +0530 Subject: [PATCH 3/3] fix: deduplicate catalog fetches and normalize --region case --- pkg/cmd/gpucreate/gpucreate.go | 122 ++++++++++++++-------------- pkg/cmd/gpucreate/gpucreate_test.go | 58 +++++++++++-- 2 files changed, 114 insertions(+), 66 deletions(-) diff --git a/pkg/cmd/gpucreate/gpucreate.go b/pkg/cmd/gpucreate/gpucreate.go index 66523c2d..0c74db8b 100644 --- a/pkg/cmd/gpucreate/gpucreate.go +++ b/pkg/cmd/gpucreate/gpucreate.go @@ -181,6 +181,12 @@ func NewCmdGPUCreate(t *terminal.Terminal, gpuCreateStore GPUCreateStore) *cobra name = args[0] } + // Normalize --region: provider region names are case-sensitive on the + // server side (e.g. GCP rejects "US-WEST1", accepts "us-west1"), while + // our client-side validation is case-insensitive. Lowercase here so + // what we validate matches what we send. + filters.region = strings.ToLower(strings.TrimSpace(filters.region)) + launchableID, err := parseLaunchableID(launchable) if err != nil { return err @@ -237,25 +243,41 @@ func NewCmdGPUCreate(t *terminal.Terminal, gpuCreateStore GPUCreateStore) *cobra Region: filters.region, } + // Fetch the instance-type catalog at most once per invocation. Region + // validation, auto-search, and dry-run preview all share the same response + // so we never issue more than one round-trip. --type explicit + no region + // skips the fetch entirely (runDryRun returns early on explicit specs). + var catalogItems []gpusearch.InstanceType + needCatalog := filters.region != "" || (len(types) == 0 && launchableID == "") + if needCatalog { + resp, err := gpuCreateStore.GetInstanceTypes(false) + if err != nil { + return breverrors.WrapAndTrace(err) + } + if resp != nil { + catalogItems = resp.Items + } + } + if filters.region != "" { - if err := validateRegionExists(filters.region, gpuCreateStore); err != nil { + if err := validateRegionExists(filters.region, catalogItems); err != nil { return err } } - opts.InstanceTypes, err = resolveInstanceTypes(cmd, gpuCreateStore, opts, types, &filters) + opts.InstanceTypes, err = resolveInstanceTypes(cmd, catalogItems, opts, types, &filters) if err != nil { return err } if filters.region != "" { - if err := validateRegion(filters.region, opts.InstanceTypes, gpuCreateStore); err != nil { + if err := validateRegion(filters.region, opts.InstanceTypes, catalogItems); err != nil { return err } } if dryRun { - return runDryRun(t, gpuCreateStore, opts.InstanceTypes, &filters) + return runDryRun(t, catalogItems, opts.InstanceTypes, &filters) } return RunGPUCreate(t, gpuCreateStore, opts) @@ -463,8 +485,9 @@ func launchableBuildModeName(info *store.LaunchableResponse) string { } } -// resolveInstanceTypes determines instance types from launchable, flags, or filters -func resolveInstanceTypes(cmd *cobra.Command, gpuCreateStore GPUCreateStore, opts GPUCreateOptions, types []InstanceSpec, filters *searchFilterFlags) ([]InstanceSpec, error) { +// resolveInstanceTypes determines instance types from launchable, flags, or filters. +// Operates on a pre-fetched catalog when auto-search is needed. +func resolveInstanceTypes(cmd *cobra.Command, items []gpusearch.InstanceType, opts GPUCreateOptions, types []InstanceSpec, filters *searchFilterFlags) ([]InstanceSpec, error) { if opts.LaunchableID != "" && len(types) == 0 && !cmd.Flags().Changed("type") { instanceType := "" if opts.LaunchableInfo != nil { @@ -477,10 +500,7 @@ func resolveInstanceTypes(cmd *cobra.Command, gpuCreateStore GPUCreateStore, opt } if len(types) == 0 { - filtered, err := getFilteredInstanceTypes(gpuCreateStore, filters) - if err != nil { - return nil, breverrors.WrapAndTrace(err) - } + filtered := getFilteredInstanceTypes(items, filters) if len(filtered) == 0 { return nil, breverrors.NewValidationError("no GPU instances match the specified filters. Try 'brev search' to see available options") } @@ -511,15 +531,11 @@ func parseStartupScript(value string) (string, error) { return value, nil } -// searchInstances fetches and filters GPU instances using user-provided filters merged with defaults -func searchInstances(s GPUCreateStore, filters *searchFilterFlags) ([]gpusearch.GPUInstanceInfo, float64, error) { - response, err := s.GetInstanceTypes(false) - if err != nil { - return nil, 0, breverrors.WrapAndTrace(err) - } - - if response == nil || len(response.Items) == 0 { - return nil, 0, nil +// searchInstances filters GPU instances using user-provided filters merged with defaults. +// Operates on a pre-fetched catalog so callers can share one network round-trip. +func searchInstances(items []gpusearch.InstanceType, filters *searchFilterFlags) ([]gpusearch.GPUInstanceInfo, float64) { + if len(items) == 0 { + return nil, 0 } minTotalVRAM := orDefault(filters.minTotalVRAM, defaultMinTotalVRAM) @@ -534,21 +550,19 @@ func searchInstances(s GPUCreateStore, filters *searchFilterFlags) ([]gpusearch. sortBy = "price" } - instances := gpusearch.ProcessInstances(response.Items) + instances := gpusearch.ProcessInstances(items) filtered := gpusearch.FilterInstances(instances, filters.gpuName, filters.region, filters.provider, "", filters.minVRAM, minTotalVRAM, minCapability, 0, minDisk, 0, maxBootTime, filters.stoppable, filters.rebootable, filters.flexPorts, true) gpusearch.SortInstances(filtered, sortBy, filters.descending) - return filtered, minDisk, nil + return filtered, minDisk } -// getFilteredInstanceTypes fetches GPU instance types using user-provided filters +// getFilteredInstanceTypes filters GPU instance types using user-provided filters // merged with defaults. When a filter flag is not set, the default value is used. -func getFilteredInstanceTypes(s GPUCreateStore, filters *searchFilterFlags) ([]InstanceSpec, error) { - filtered, minDisk, err := searchInstances(s, filters) - if err != nil { - return nil, breverrors.WrapAndTrace(err) - } +// Operates on a pre-fetched catalog. +func getFilteredInstanceTypes(items []gpusearch.InstanceType, filters *searchFilterFlags) []InstanceSpec { + filtered, minDisk := searchInstances(items, filters) var specs []InstanceSpec for _, inst := range filtered { @@ -559,20 +573,18 @@ func getFilteredInstanceTypes(s GPUCreateStore, filters *searchFilterFlags) ([]I specs = append(specs, InstanceSpec{Type: inst.Type, DiskGB: diskGB}) } - return specs, nil + return specs } -// runDryRun shows the instance types that would be used without creating anything -func runDryRun(t *terminal.Terminal, s GPUCreateStore, specs []InstanceSpec, filters *searchFilterFlags) error { +// runDryRun shows the instance types that would be used without creating anything. +// Operates on a pre-fetched catalog. +func runDryRun(t *terminal.Terminal, items []gpusearch.InstanceType, specs []InstanceSpec, filters *searchFilterFlags) error { if len(specs) > 0 { t.Print(formatInstanceSpecs(specs)) return nil } - filtered, _, err := searchInstances(s, filters) - if err != nil { - return breverrors.WrapAndTrace(err) - } + filtered, _ := searchInstances(items, filters) piped := gpusearch.IsStdoutPiped() if err := gpusearch.DisplayGPUResults(t, filtered, false, piped, false); err != nil { @@ -582,18 +594,14 @@ func runDryRun(t *terminal.Terminal, s GPUCreateStore, specs []InstanceSpec, fil } // validateRegionExists checks that the given region appears in at least one instance type in the catalog. -func validateRegionExists(region string, store GPUCreateStore) error { - response, err := store.GetInstanceTypes(false) - if err != nil { - return breverrors.WrapAndTrace(err) - } - if response == nil || len(response.Items) == 0 { +// Operates on a pre-fetched catalog. +func validateRegionExists(region string, items []gpusearch.InstanceType) error { + if len(items) == 0 { return nil } - regionLower := strings.ToLower(region) - for _, item := range response.Items { - if typeSupportsRegion(item, regionLower) { + for _, item := range items { + if typeSupportsRegion(item, region) { return nil } } @@ -605,25 +613,17 @@ func validateRegionExists(region string, store GPUCreateStore) error { } // validateRegion checks that every requested instance type is available in the given region. -func validateRegion(region string, types []InstanceSpec, store GPUCreateStore) error { - if len(types) == 0 { - return nil - } - - response, err := store.GetInstanceTypes(false) - if err != nil { - return breverrors.WrapAndTrace(err) - } - if response == nil || len(response.Items) == 0 { +// Operates on a pre-fetched catalog. +func validateRegion(region string, types []InstanceSpec, items []gpusearch.InstanceType) error { + if len(types) == 0 || len(items) == 0 { return nil } - catalog := make(map[string]gpusearch.InstanceType, len(response.Items)) - for _, item := range response.Items { + catalog := make(map[string]gpusearch.InstanceType, len(items)) + for _, item := range items { catalog[item.Type] = item } - regionLower := strings.ToLower(region) var unsupported []string var unknown []string @@ -633,7 +633,7 @@ func validateRegion(region string, types []InstanceSpec, store GPUCreateStore) e unknown = append(unknown, spec.Type) continue } - if !typeSupportsRegion(item, regionLower) { + if !typeSupportsRegion(item, region) { unsupported = append(unsupported, spec.Type) } } @@ -653,10 +653,12 @@ func validateRegion(region string, types []InstanceSpec, store GPUCreateStore) e } // typeSupportsRegion reports whether an instance type lists the given region -// (already lowercased) in its AvailableLocations, using substring matching. -func typeSupportsRegion(item gpusearch.InstanceType, regionLower string) bool { +// in its AvailableLocations. Match is case-insensitive but must be an exact +// equality — substring matching would let e.g. --region "us" pass validation +// against "us-east-1", and then we'd forward "us" to the server as the region. +func typeSupportsRegion(item gpusearch.InstanceType, region string) bool { for _, loc := range item.AvailableLocations { - if strings.Contains(strings.ToLower(loc), regionLower) { + if strings.EqualFold(loc, region) { return true } } diff --git a/pkg/cmd/gpucreate/gpucreate_test.go b/pkg/cmd/gpucreate/gpucreate_test.go index 5f71cdae..78054a8b 100644 --- a/pkg/cmd/gpucreate/gpucreate_test.go +++ b/pkg/cmd/gpucreate/gpucreate_test.go @@ -27,6 +27,7 @@ type MockGPUCreateStore struct { CreatedWorkspaces []*entity.Workspace DeletedWorkspaceIDs []string FetchedLifeCycleScriptIDs []string + GetInstanceTypesCallCount int } func NewMockGPUCreateStore() *MockGPUCreateStore { @@ -127,6 +128,7 @@ func (m *MockGPUCreateStore) RedeemCouponCode(organizationID string, code string } func (m *MockGPUCreateStore) GetInstanceTypes(_ bool) (*gpusearch.InstanceTypesResponse, error) { + m.GetInstanceTypesCallCount++ // Return a default set of instance types for testing return &gpusearch.InstanceTypesResponse{ Items: []gpusearch.InstanceType{ @@ -776,16 +778,58 @@ func TestCreateDryRunWithExplicitTypesDoesNotProvision(t *testing.T) { assert.Empty(t, mock.CreatedWorkspaces) } +func TestCreateFetchesCatalogAtMostOnce(t *testing.T) { + tests := []struct { + name string + args []string + wantFetchCount int + }{ + { + name: "explicit type, no region — no catalog fetch needed", + args: []string{"no-fetch", "--type", "g5.xlarge", "--dry-run"}, + wantFetchCount: 0, + }, + { + name: "auto-search needs catalog once", + args: []string{"auto", "--dry-run"}, + wantFetchCount: 1, + }, + { + name: "region triggers catalog fetch, reused across validations", + args: []string{"reg", "--type", "g5.xlarge", "--region", "us-east-1", "--dry-run"}, + wantFetchCount: 1, + }, + { + name: "region + auto-search shares one fetch across validators and resolveInstanceTypes", + args: []string{"reg-auto", "--region", "us-east-1", "--dry-run"}, + wantFetchCount: 1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mock := NewMockGPUCreateStore() + term := terminal.New() + + cmd := NewCmdGPUCreate(term, mock) + cmd.SetArgs(tt.args) + _ = cmd.Execute() // success not required — count assertion is the point + assert.Equal(t, tt.wantFetchCount, mock.GetInstanceTypesCallCount, + "GetInstanceTypes should be called %d time(s) for args %v", tt.wantFetchCount, tt.args) + }) + } +} + func TestGetFilteredInstanceTypesDefaults(t *testing.T) { mock := NewMockGPUCreateStore() + resp, err := mock.GetInstanceTypes(false) + assert.NoError(t, err) // Get instance types with no user filters (uses defaults): // - 24GB VRAM (>= 20GB total VRAM requirement) // - 500GB disk (>= 500GB requirement) // - A10G GPU = 8.6 capability (>= 8.0 requirement) // - 5m boot time (< 7m requirement) - specs, err := getFilteredInstanceTypes(mock, &searchFilterFlags{}) - assert.NoError(t, err) + specs := getFilteredInstanceTypes(resp.Items, &searchFilterFlags{}) assert.Len(t, specs, 1) assert.Equal(t, "g5.xlarge", specs[0].Type) assert.Equal(t, 500.0, specs[0].DiskGB) // Should use the instance's disk size @@ -793,20 +837,22 @@ func TestGetFilteredInstanceTypesDefaults(t *testing.T) { func TestGetFilteredInstanceTypesWithGPUName(t *testing.T) { mock := NewMockGPUCreateStore() + resp, err := mock.GetInstanceTypes(false) + assert.NoError(t, err) // Filter by GPU name that matches the mock data - specs, err := getFilteredInstanceTypes(mock, &searchFilterFlags{gpuName: "A10G"}) - assert.NoError(t, err) + specs := getFilteredInstanceTypes(resp.Items, &searchFilterFlags{gpuName: "A10G"}) assert.Len(t, specs, 1) assert.Equal(t, "g5.xlarge", specs[0].Type) } func TestGetFilteredInstanceTypesNoMatch(t *testing.T) { mock := NewMockGPUCreateStore() + resp, err := mock.GetInstanceTypes(false) + assert.NoError(t, err) // Filter by GPU name that doesn't match - specs, err := getFilteredInstanceTypes(mock, &searchFilterFlags{gpuName: "H100"}) - assert.NoError(t, err) + specs := getFilteredInstanceTypes(resp.Items, &searchFilterFlags{gpuName: "H100"}) assert.Len(t, specs, 0) }