Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 138 additions & 34 deletions pkg/cmd/gpucreate/gpucreate.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ type CreateResult struct {
// searchFilterFlags holds the search filter flag values for create
type searchFilterFlags struct {
gpuName string
region string
provider string
minVRAM float64
minTotalVRAM float64
Expand All @@ -144,7 +145,7 @@ type searchFilterFlags struct {

// hasUserFilters returns true if the user specified any search filter flags
func (f *searchFilterFlags) hasUserFilters() bool {
return f.gpuName != "" || f.provider != "" || f.minVRAM > 0 || f.minTotalVRAM > 0 ||
return f.gpuName != "" || f.region != "" || f.provider != "" || f.minVRAM > 0 || f.minTotalVRAM > 0 ||
f.minCapability > 0 || f.minDisk > 0 || f.maxBootTime > 0 ||
f.stoppable || f.rebootable || f.flexPorts
}
Expand Down Expand Up @@ -180,6 +181,12 @@ func NewCmdGPUCreate(t *terminal.Terminal, gpuCreateStore GPUCreateStore) *cobra
name = args[0]
}

// Normalize --region: provider region names are case-sensitive on the
// server side (e.g. GCP rejects "US-WEST1", accepts "us-west1"), while
// our client-side validation is case-insensitive. Lowercase here so
// what we validate matches what we send.
filters.region = strings.ToLower(strings.TrimSpace(filters.region))

launchableID, err := parseLaunchableID(launchable)
if err != nil {
return err
Expand Down Expand Up @@ -233,15 +240,44 @@ func NewCmdGPUCreate(t *terminal.Terminal, gpuCreateStore GPUCreateStore) *cobra
ComposeFile: composeFile,
LaunchableID: launchableID,
LaunchableInfo: launchableInfo,
Region: filters.region,
}

// Fetch the instance-type catalog at most once per invocation. Region
// validation, auto-search, and dry-run preview all share the same response
// so we never issue more than one round-trip. --type explicit + no region
// skips the fetch entirely (runDryRun returns early on explicit specs).
var catalogItems []gpusearch.InstanceType
needCatalog := filters.region != "" || (len(types) == 0 && launchableID == "")
if needCatalog {
resp, err := gpuCreateStore.GetInstanceTypes(false)
if err != nil {
return breverrors.WrapAndTrace(err)
}
if resp != nil {
catalogItems = resp.Items
}
}

opts.InstanceTypes, err = resolveInstanceTypes(cmd, gpuCreateStore, opts, types, &filters)
if filters.region != "" {
if err := validateRegionExists(filters.region, catalogItems); err != nil {
return err
}
}

opts.InstanceTypes, err = resolveInstanceTypes(cmd, catalogItems, opts, types, &filters)
if err != nil {
return err
}

if filters.region != "" {
if err := validateRegion(filters.region, opts.InstanceTypes, catalogItems); err != nil {
return err
}
}
Comment thread
hakhandelwal11 marked this conversation as resolved.

if dryRun {
return runDryRun(t, gpuCreateStore, opts.InstanceTypes, &filters)
return runDryRun(t, catalogItems, opts.InstanceTypes, &filters)
}

return RunGPUCreate(t, gpuCreateStore, opts)
Expand Down Expand Up @@ -281,6 +317,7 @@ func registerCreateFlags(cmd *cobra.Command, name, instanceTypes *string, count,
cmd.Flags().StringVar(containerImage, "container-image", "", "Container image URL (required for container mode)")
cmd.Flags().StringVar(composeFile, "compose-file", "", "Docker compose file path or URL (required for compose mode)")
cmd.Flags().StringVarP(launchable, "launchable", "l", "", "Launchable ID or URL to deploy (e.g., env-XXX or console URL)")
cmd.Flags().StringVarP(&filters.region, "region", "r", "", "Region/location to deploy the instance (e.g., us-east-1, us-central1)")

cmd.Flags().StringVarP(&filters.gpuName, "gpu-name", "g", "", "Filter by GPU name (e.g., A100, H100)")
cmd.Flags().StringVar(&filters.provider, "provider", "", "Filter by provider/cloud (e.g., aws, gcp)")
Expand Down Expand Up @@ -318,6 +355,7 @@ type GPUCreateOptions struct {
ComposeFile string
LaunchableID string
LaunchableInfo *store.LaunchableResponse // populated when LaunchableID is set
Region string
}

// parseLaunchableID extracts a launchable ID from either a raw ID (env-XXX) or
Expand Down Expand Up @@ -447,8 +485,9 @@ func launchableBuildModeName(info *store.LaunchableResponse) string {
}
}

// resolveInstanceTypes determines instance types from launchable, flags, or filters
func resolveInstanceTypes(cmd *cobra.Command, gpuCreateStore GPUCreateStore, opts GPUCreateOptions, types []InstanceSpec, filters *searchFilterFlags) ([]InstanceSpec, error) {
// resolveInstanceTypes determines instance types from launchable, flags, or filters.
// Operates on a pre-fetched catalog when auto-search is needed.
func resolveInstanceTypes(cmd *cobra.Command, items []gpusearch.InstanceType, opts GPUCreateOptions, types []InstanceSpec, filters *searchFilterFlags) ([]InstanceSpec, error) {
if opts.LaunchableID != "" && len(types) == 0 && !cmd.Flags().Changed("type") {
instanceType := ""
if opts.LaunchableInfo != nil {
Expand All @@ -461,10 +500,7 @@ func resolveInstanceTypes(cmd *cobra.Command, gpuCreateStore GPUCreateStore, opt
}

if len(types) == 0 {
filtered, err := getFilteredInstanceTypes(gpuCreateStore, filters)
if err != nil {
return nil, breverrors.WrapAndTrace(err)
}
filtered := getFilteredInstanceTypes(items, filters)
if len(filtered) == 0 {
return nil, breverrors.NewValidationError("no GPU instances match the specified filters. Try 'brev search' to see available options")
}
Expand Down Expand Up @@ -495,15 +531,11 @@ func parseStartupScript(value string) (string, error) {
return value, nil
}

// searchInstances fetches and filters GPU instances using user-provided filters merged with defaults
func searchInstances(s GPUCreateStore, filters *searchFilterFlags) ([]gpusearch.GPUInstanceInfo, float64, error) {
response, err := s.GetInstanceTypes(false)
if err != nil {
return nil, 0, breverrors.WrapAndTrace(err)
}

if response == nil || len(response.Items) == 0 {
return nil, 0, nil
// searchInstances filters GPU instances using user-provided filters merged with defaults.
// Operates on a pre-fetched catalog so callers can share one network round-trip.
func searchInstances(items []gpusearch.InstanceType, filters *searchFilterFlags) ([]gpusearch.GPUInstanceInfo, float64) {
if len(items) == 0 {
return nil, 0
}

minTotalVRAM := orDefault(filters.minTotalVRAM, defaultMinTotalVRAM)
Expand All @@ -518,21 +550,19 @@ func searchInstances(s GPUCreateStore, filters *searchFilterFlags) ([]gpusearch.
sortBy = "price"
}

instances := gpusearch.ProcessInstances(response.Items)
filtered := gpusearch.FilterInstances(instances, filters.gpuName, filters.provider, "", filters.minVRAM,
instances := gpusearch.ProcessInstances(items)
filtered := gpusearch.FilterInstances(instances, filters.gpuName, filters.region, filters.provider, "", filters.minVRAM,
minTotalVRAM, minCapability, 0, minDisk, 0, maxBootTime, filters.stoppable, filters.rebootable, filters.flexPorts, true)
gpusearch.SortInstances(filtered, sortBy, filters.descending)

return filtered, minDisk, nil
return filtered, minDisk
}

// getFilteredInstanceTypes fetches GPU instance types using user-provided filters
// getFilteredInstanceTypes filters GPU instance types using user-provided filters
// merged with defaults. When a filter flag is not set, the default value is used.
func getFilteredInstanceTypes(s GPUCreateStore, filters *searchFilterFlags) ([]InstanceSpec, error) {
filtered, minDisk, err := searchInstances(s, filters)
if err != nil {
return nil, breverrors.WrapAndTrace(err)
}
// Operates on a pre-fetched catalog.
func getFilteredInstanceTypes(items []gpusearch.InstanceType, filters *searchFilterFlags) []InstanceSpec {
filtered, minDisk := searchInstances(items, filters)

var specs []InstanceSpec
for _, inst := range filtered {
Expand All @@ -543,20 +573,18 @@ func getFilteredInstanceTypes(s GPUCreateStore, filters *searchFilterFlags) ([]I
specs = append(specs, InstanceSpec{Type: inst.Type, DiskGB: diskGB})
}

return specs, nil
return specs
}

// runDryRun shows the instance types that would be used without creating anything
func runDryRun(t *terminal.Terminal, s GPUCreateStore, specs []InstanceSpec, filters *searchFilterFlags) error {
// runDryRun shows the instance types that would be used without creating anything.
// Operates on a pre-fetched catalog.
func runDryRun(t *terminal.Terminal, items []gpusearch.InstanceType, specs []InstanceSpec, filters *searchFilterFlags) error {
if len(specs) > 0 {
t.Print(formatInstanceSpecs(specs))
return nil
}

filtered, _, err := searchInstances(s, filters)
if err != nil {
return breverrors.WrapAndTrace(err)
}
filtered, _ := searchInstances(items, filters)

piped := gpusearch.IsStdoutPiped()
if err := gpusearch.DisplayGPUResults(t, filtered, false, piped, false); err != nil {
Expand All @@ -565,6 +593,78 @@ func runDryRun(t *terminal.Terminal, s GPUCreateStore, specs []InstanceSpec, fil
return nil
}

// validateRegionExists checks that the given region appears in at least one instance type in the catalog.
// Operates on a pre-fetched catalog.
func validateRegionExists(region string, items []gpusearch.InstanceType) error {
if len(items) == 0 {
return nil
}

for _, item := range items {
if typeSupportsRegion(item, region) {
return nil
}
}

return breverrors.NewValidationError(
fmt.Sprintf("region %q is not offered by any instance type -- use 'brev search --json' to list valid regions",
region),
)
}

// validateRegion checks that every requested instance type is available in the given region.
// Operates on a pre-fetched catalog.
func validateRegion(region string, types []InstanceSpec, items []gpusearch.InstanceType) error {
if len(types) == 0 || len(items) == 0 {
return nil
}

catalog := make(map[string]gpusearch.InstanceType, len(items))
for _, item := range items {
catalog[item.Type] = item
}

var unsupported []string
var unknown []string

for _, spec := range types {
item, ok := catalog[spec.Type]
if !ok {
unknown = append(unknown, spec.Type)
continue
}
if !typeSupportsRegion(item, region) {
unsupported = append(unsupported, spec.Type)
}
}

if len(unknown) > 0 {
return breverrors.NewValidationError(
fmt.Sprintf("unknown instance type(s) %s -- use 'brev search' to list available types", strings.Join(unknown, ", ")),
)
}
if len(unsupported) > 0 {
return breverrors.NewValidationError(
fmt.Sprintf("region %q is not available for instance type(s) %s -- use 'brev search --region %s' to find compatible types",
region, strings.Join(unsupported, ", "), region),
)
}
return nil
}

// typeSupportsRegion reports whether an instance type lists the given region
// in its AvailableLocations. Match is case-insensitive but must be an exact
// equality — substring matching would let e.g. --region "us" pass validation
// against "us-east-1", and then we'd forward "us" to the server as the region.
func typeSupportsRegion(item gpusearch.InstanceType, region string) bool {
for _, loc := range item.AvailableLocations {
if strings.EqualFold(loc, region) {
return true
}
}
return false
}

// orDefault returns val if it's non-zero, otherwise returns def
func orDefault(val, def float64) float64 {
if val > 0 {
Expand Down Expand Up @@ -1047,6 +1147,10 @@ func (c *createContext) createWorkspace(name string, spec InstanceSpec) (*entity
}
}

if c.opts.Region != "" {
cwOptions.Location = c.opts.Region
}

// Apply launchable config or build mode
if c.opts.LaunchableID != "" {
applyLaunchableConfig(cwOptions, c.opts.LaunchableID, c.opts.LaunchableInfo)
Expand Down
58 changes: 52 additions & 6 deletions pkg/cmd/gpucreate/gpucreate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type MockGPUCreateStore struct {
CreatedWorkspaces []*entity.Workspace
DeletedWorkspaceIDs []string
FetchedLifeCycleScriptIDs []string
GetInstanceTypesCallCount int
}

func NewMockGPUCreateStore() *MockGPUCreateStore {
Expand Down Expand Up @@ -127,6 +128,7 @@ func (m *MockGPUCreateStore) RedeemCouponCode(organizationID string, code string
}

func (m *MockGPUCreateStore) GetInstanceTypes(_ bool) (*gpusearch.InstanceTypesResponse, error) {
m.GetInstanceTypesCallCount++
// Return a default set of instance types for testing
return &gpusearch.InstanceTypesResponse{
Items: []gpusearch.InstanceType{
Expand Down Expand Up @@ -776,37 +778,81 @@ func TestCreateDryRunWithExplicitTypesDoesNotProvision(t *testing.T) {
assert.Empty(t, mock.CreatedWorkspaces)
}

func TestCreateFetchesCatalogAtMostOnce(t *testing.T) {
tests := []struct {
name string
args []string
wantFetchCount int
}{
{
name: "explicit type, no region — no catalog fetch needed",
args: []string{"no-fetch", "--type", "g5.xlarge", "--dry-run"},
wantFetchCount: 0,
},
{
name: "auto-search needs catalog once",
args: []string{"auto", "--dry-run"},
wantFetchCount: 1,
},
{
name: "region triggers catalog fetch, reused across validations",
args: []string{"reg", "--type", "g5.xlarge", "--region", "us-east-1", "--dry-run"},
wantFetchCount: 1,
},
{
name: "region + auto-search shares one fetch across validators and resolveInstanceTypes",
args: []string{"reg-auto", "--region", "us-east-1", "--dry-run"},
wantFetchCount: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mock := NewMockGPUCreateStore()
term := terminal.New()

cmd := NewCmdGPUCreate(term, mock)
cmd.SetArgs(tt.args)
_ = cmd.Execute() // success not required — count assertion is the point
assert.Equal(t, tt.wantFetchCount, mock.GetInstanceTypesCallCount,
"GetInstanceTypes should be called %d time(s) for args %v", tt.wantFetchCount, tt.args)
})
}
}

func TestGetFilteredInstanceTypesDefaults(t *testing.T) {
mock := NewMockGPUCreateStore()
resp, err := mock.GetInstanceTypes(false)
assert.NoError(t, err)

// Get instance types with no user filters (uses defaults):
// - 24GB VRAM (>= 20GB total VRAM requirement)
// - 500GB disk (>= 500GB requirement)
// - A10G GPU = 8.6 capability (>= 8.0 requirement)
// - 5m boot time (< 7m requirement)
specs, err := getFilteredInstanceTypes(mock, &searchFilterFlags{})
assert.NoError(t, err)
specs := getFilteredInstanceTypes(resp.Items, &searchFilterFlags{})
assert.Len(t, specs, 1)
assert.Equal(t, "g5.xlarge", specs[0].Type)
assert.Equal(t, 500.0, specs[0].DiskGB) // Should use the instance's disk size
}

func TestGetFilteredInstanceTypesWithGPUName(t *testing.T) {
mock := NewMockGPUCreateStore()
resp, err := mock.GetInstanceTypes(false)
assert.NoError(t, err)

// Filter by GPU name that matches the mock data
specs, err := getFilteredInstanceTypes(mock, &searchFilterFlags{gpuName: "A10G"})
assert.NoError(t, err)
specs := getFilteredInstanceTypes(resp.Items, &searchFilterFlags{gpuName: "A10G"})
assert.Len(t, specs, 1)
assert.Equal(t, "g5.xlarge", specs[0].Type)
}

func TestGetFilteredInstanceTypesNoMatch(t *testing.T) {
mock := NewMockGPUCreateStore()
resp, err := mock.GetInstanceTypes(false)
assert.NoError(t, err)

// Filter by GPU name that doesn't match
specs, err := getFilteredInstanceTypes(mock, &searchFilterFlags{gpuName: "H100"})
assert.NoError(t, err)
specs := getFilteredInstanceTypes(resp.Items, &searchFilterFlags{gpuName: "H100"})
assert.Len(t, specs, 0)
}

Expand Down
Loading
Loading