Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 36 additions & 19 deletions pkg/cmd/gpucreate/gpucreate.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ type GPUCreateStore interface {
GetWorkspace(workspaceID string) (*entity.Workspace, error)
CreateWorkspace(organizationID string, options *store.CreateWorkspacesOptions) (*entity.Workspace, error)
DeleteWorkspace(workspaceID string) (*entity.Workspace, error)
GetAllInstanceTypesWithWorkspaceGroups(orgID string) (*gpusearch.AllInstanceTypesResponse, error)
GetAllInstanceTypesWithCloudCreds(orgID string) (*gpusearch.AllInstanceTypesResponse, error)
GetLaunchable(launchableID string) (*store.LaunchableResponse, error)
GetLaunchableLifeCycleScript(launchableID, scriptID string) (*store.LifeCycleScriptResponse, error)
RedeemCouponCode(organizationID string, code string) (*store.RedeemCouponCodeResponse, error)
Expand Down Expand Up @@ -768,11 +768,11 @@ func newCreateContext(t *terminal.Terminal, store GPUCreateStore, opts GPUCreate
}
ctx.org = org

// Fetch instance types with workspace groups
allInstanceTypes, err := store.GetAllInstanceTypesWithWorkspaceGroups(org.ID)
// Fetch instance types with cloud credentials.
allInstanceTypes, err := store.GetAllInstanceTypesWithCloudCreds(org.ID)
if err != nil {
ctx.logf("Warning: could not fetch instance types with workspace groups: %s\n", err.Error())
ctx.logf("Falling back to default workspace group\n")
ctx.logf("Warning: could not fetch instance types with cloud credentials: %s\n", err.Error())
ctx.logf("Falling back to default cloud credential\n")
}
ctx.allInstanceTypes = allInstanceTypes

Expand All @@ -791,7 +791,7 @@ func (c *createContext) validateInstanceTypeAvailability(instanceType string) er
if c.allInstanceTypes == nil {
return nil
}
if c.allInstanceTypes.GetWorkspaceGroupID(instanceType) != "" {
if c.allInstanceTypes.GetCloudCredID(instanceType) != "" {
return nil
}
if !c.allInstanceTypes.HasInstanceType(instanceType) {
Expand Down Expand Up @@ -1042,8 +1042,8 @@ func (c *createContext) createWorkspace(name string, spec InstanceSpec) (*entity
}

if c.allInstanceTypes != nil {
if wgID := c.allInstanceTypes.GetWorkspaceGroupID(spec.Type); wgID != "" {
cwOptions.WorkspaceGroupID = wgID
if cloudCredID := c.allInstanceTypes.GetCloudCredID(spec.Type); cloudCredID != "" {
cwOptions.WithCloudCredID(cloudCredID)
}
}

Expand All @@ -1060,7 +1060,7 @@ func (c *createContext) createWorkspace(name string, spec InstanceSpec) (*entity
if cwOptions.WorkspaceGroupID == "" {
if c.allInstanceTypes == nil {
return nil, breverrors.NewValidationError(fmt.Sprintf(
"could not resolve workspace group for %q (instance-type listing was unavailable); please retry",
"could not resolve cloud credential for %q (instance-type listing was unavailable); please retry",
spec.Type,
))
}
Expand Down Expand Up @@ -1176,11 +1176,22 @@ func applyLaunchableConfig(cwOptions *store.CreateWorkspacesOptions, launchableI
return
}

wsReq := info.CreateWorkspaceRequest
applyLaunchableWorkspaceRequest(cwOptions, info.CreateWorkspaceRequest)
applyLaunchableBuildRequest(cwOptions, info.BuildRequest)
applyLaunchableFile(cwOptions, info.File)
applyLaunchableLabels(cwOptions, launchableID, info)
}

// Use launchable's workspace group if not already resolved from instance types
func applyLaunchableWorkspaceRequest(cwOptions *store.CreateWorkspacesOptions, wsReq store.LaunchableWorkspaceRequest) {
// Use launchable's cloud credential if not already resolved from instance types.
if cwOptions.CloudCredID == "" && wsReq.CloudCredID != "" {
cwOptions.WithCloudCredID(wsReq.CloudCredID)
}
if cwOptions.WorkspaceGroupID == "" && wsReq.WorkspaceGroupID != "" {
cwOptions.WorkspaceGroupID = wsReq.WorkspaceGroupID
if cwOptions.CloudCredID == "" {
cwOptions.CloudCredID = wsReq.WorkspaceGroupID
}
}

// Location / sub-location
Expand All @@ -1198,8 +1209,13 @@ func applyLaunchableConfig(cwOptions *store.CreateWorkspacesOptions, launchableI
cwOptions.DiskStorage = normalizeDiskStorage(wsReq.Storage)
}

if len(wsReq.FirewallRules) > 0 {
cwOptions.FirewallRules = resolveFirewallRulesClientIP(wsReq.FirewallRules, publicIPLookup)
}
}

func applyLaunchableBuildRequest(cwOptions *store.CreateWorkspacesOptions, build store.LaunchableBuildRequest) {
// Build configuration from launchable
build := info.BuildRequest
switch {
case build.VMBuild != nil:
cwOptions.VMBuild = build.VMBuild
Expand All @@ -1219,18 +1235,18 @@ func applyLaunchableConfig(cwOptions *store.CreateWorkspacesOptions, launchableI
}
cwOptions.PortMappings = portMappings
}
}

if len(wsReq.FirewallRules) > 0 {
cwOptions.FirewallRules = resolveFirewallRulesClientIP(wsReq.FirewallRules, publicIPLookup)
}

func applyLaunchableFile(cwOptions *store.CreateWorkspacesOptions, file *store.LaunchableFile) {
// Files from launchable
if info.File != nil {
if file != nil {
cwOptions.Files = []map[string]string{
{"url": info.File.URL, "path": info.File.Path},
{"url": file.URL, "path": file.Path},
}
}
}

func applyLaunchableLabels(cwOptions *store.CreateWorkspacesOptions, launchableID string, info *store.LaunchableResponse) {
// Labels for tracking and UI rendering — merge with any existing labels
var labels map[string]string
if existing, ok := cwOptions.Labels.(map[string]string); ok && existing != nil {
Expand All @@ -1239,7 +1255,8 @@ func applyLaunchableConfig(cwOptions *store.CreateWorkspacesOptions, launchableI
labels = make(map[string]string)
}
labels["launchableId"] = launchableID
labels["launchableInstanceType"] = wsReq.InstanceType
labels["launchableInstanceType"] = info.CreateWorkspaceRequest.InstanceType
labels["cloudCredId"] = cwOptions.CloudCredID
labels["workspaceGroupId"] = cwOptions.WorkspaceGroupID
labels["launchableCreatedByUserId"] = info.CreatedByUserID
labels["launchableCreatedByOrgId"] = info.CreatedByOrgID
Expand Down
39 changes: 37 additions & 2 deletions pkg/cmd/gpucreate/gpucreate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/brevdev/brev-cli/pkg/store"
"github.com/brevdev/brev-cli/pkg/terminal"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// MockGPUCreateStore is a mock implementation of GPUCreateStore for testing
Expand All @@ -24,6 +25,7 @@ type MockGPUCreateStore struct {
CreateError error
CreateErrorTypes map[string]error // Errors for specific instance types
DeleteError error
CreatedOptions []*store.CreateWorkspacesOptions
CreatedWorkspaces []*entity.Workspace
DeletedWorkspaceIDs []string
FetchedLifeCycleScriptIDs []string
Expand Down Expand Up @@ -85,6 +87,7 @@ func (m *MockGPUCreateStore) CreateWorkspace(organizationID string, options *sto
Status: entity.Running,
}
m.Workspaces[ws.ID] = ws
m.CreatedOptions = append(m.CreatedOptions, options)
m.CreatedWorkspaces = append(m.CreatedWorkspaces, ws)
return ws, nil
}
Expand All @@ -104,7 +107,7 @@ func (m *MockGPUCreateStore) GetWorkspaceByNameOrID(orgID string, nameOrID strin
return []entity.Workspace{}, nil
}

func (m *MockGPUCreateStore) GetAllInstanceTypesWithWorkspaceGroups(orgID string) (*gpusearch.AllInstanceTypesResponse, error) {
func (m *MockGPUCreateStore) GetAllInstanceTypesWithCloudCreds(orgID string) (*gpusearch.AllInstanceTypesResponse, error) {
return nil, nil
}

Expand Down Expand Up @@ -247,7 +250,8 @@ func TestApplyLaunchableConfig(t *testing.T) { //nolint:funlen // test

applyLaunchableConfig(cwOptions, "env-abc123", info)

// Workspace group from launchable
// Cloud credential from launchable compatibility input.
assert.Equal(t, "GCP", cwOptions.CloudCredID)
assert.Equal(t, "GCP", cwOptions.WorkspaceGroupID)
// Location / sub-location
assert.Equal(t, "us-west1", cwOptions.Location)
Expand All @@ -274,13 +278,15 @@ func TestApplyLaunchableConfig(t *testing.T) { //nolint:funlen // test
assert.True(t, ok)
assert.Equal(t, "env-abc123", labels["launchableId"])
assert.Equal(t, "n2-standard-4", labels["launchableInstanceType"])
assert.Equal(t, "GCP", labels["cloudCredId"])
assert.Equal(t, "GCP", labels["workspaceGroupId"])
assert.Equal(t, "user-1", labels["launchableCreatedByUserId"])
assert.Equal(t, "org-1", labels["launchableCreatedByOrgId"])
})

t.Run("preserves existing workspace group", func(t *testing.T) {
cwOptions := &store.CreateWorkspacesOptions{
CloudCredID: "existing-wg",
WorkspaceGroupID: "existing-wg",
}
info := &store.LaunchableResponse{
Expand All @@ -293,6 +299,7 @@ func TestApplyLaunchableConfig(t *testing.T) { //nolint:funlen // test
applyLaunchableConfig(cwOptions, "env-abc", info)

assert.Equal(t, "existing-wg", cwOptions.WorkspaceGroupID)
assert.Equal(t, "existing-wg", cwOptions.CloudCredID)
})

t.Run("storage already has Gi suffix", func(t *testing.T) {
Expand Down Expand Up @@ -1116,6 +1123,34 @@ func TestCreateInstancesWithTypeSkipsUnavailableType(t *testing.T) {
assert.Empty(t, mock.CreatedWorkspaces, "CreateWorkspace must not be called when no workspace group is available")
}

func TestCreateInstancesWithTypeSetsCloudCredIDFromCatalog(t *testing.T) {
mock := NewMockGPUCreateStore()
ctx := &createContext{
t: terminal.New(),
store: mock,
opts: GPUCreateOptions{Count: 1, Parallel: 1, Name: "jt-4"},
org: mock.Org,
user: mock.User,
piped: true,
allInstanceTypes: &gpusearch.AllInstanceTypesResponse{
AllInstanceTypes: []gpusearch.InstanceType{
{
Type: "hyperstack_H100_sxm5x8",
CloudCredID: "cc-shadeform",
},
},
},
}
ctx.logf = func(_ string, _ ...interface{}) {}

result := ctx.createInstancesWithType(InstanceSpec{Type: "hyperstack_H100_sxm5x8"}, 0, 1)

assert.False(t, result.hadFailure)
require.Len(t, mock.CreatedOptions, 1)
assert.Equal(t, "cc-shadeform", mock.CreatedOptions[0].CloudCredID)
assert.Equal(t, "cc-shadeform", mock.CreatedOptions[0].WorkspaceGroupID)
}

func TestCreateInstancesWithTypeBypassesValidationForLaunchable(t *testing.T) {
mock := NewMockGPUCreateStore()
ctx := &createContext{
Expand Down
27 changes: 24 additions & 3 deletions pkg/cmd/gpusearch/gpusearch.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ type WorkspaceGroup struct {
PlatformType string `json:"platformType"`
}

// CloudCred represents a cloud credential that can run an instance type.
type CloudCred struct {
ID string `json:"id"`
Name string `json:"name"`
PlatformType string `json:"platformType"`
TenantType string `json:"tenantType"`
}

// InstanceType represents an instance type from the API
type InstanceType struct {
Type string `json:"type"`
Expand All @@ -69,6 +77,8 @@ type InstanceType struct {
SubLocation string `json:"sub_location"`
AvailableLocations []string `json:"available_locations"`
Provider string `json:"provider"`
CloudCredID string `json:"cloud_cred_id"`
CloudCreds []CloudCred `json:"cloud_creds"`
WorkspaceGroups []WorkspaceGroup `json:"workspace_groups"`
EstimatedDeployTime string `json:"estimated_deploy_time"`
Stoppable bool `json:"stoppable"`
Expand All @@ -81,15 +91,21 @@ type InstanceTypesResponse struct {
Items []InstanceType `json:"items"`
}

// AllInstanceTypesResponse represents the authenticated API response with workspace groups
// AllInstanceTypesResponse represents the authenticated API response with cloud credentials.
type AllInstanceTypesResponse struct {
AllInstanceTypes []InstanceType `json:"allInstanceTypes"`
}

// GetWorkspaceGroupID returns the workspace group ID for an instance type, or empty string if not found
func (r *AllInstanceTypesResponse) GetWorkspaceGroupID(instanceType string) string {
// GetCloudCredID returns the cloud credential ID for an instance type, or empty string if not found.
func (r *AllInstanceTypesResponse) GetCloudCredID(instanceType string) string {
for _, it := range r.AllInstanceTypes {
if it.Type == instanceType {
if it.CloudCredID != "" {
return it.CloudCredID
}
if len(it.CloudCreds) > 0 {
return it.CloudCreds[0].ID
}
if len(it.WorkspaceGroups) > 0 {
return it.WorkspaceGroups[0].ID
}
Expand All @@ -98,6 +114,11 @@ func (r *AllInstanceTypesResponse) GetWorkspaceGroupID(instanceType string) stri
return ""
}

// GetWorkspaceGroupID is a compatibility alias while create still accepts workspaceGroupId.
func (r *AllInstanceTypesResponse) GetWorkspaceGroupID(instanceType string) string {
return r.GetCloudCredID(instanceType)
}

// HasInstanceType reports whether the type exists in the API listing, independent of capacity.
func (r *AllInstanceTypesResponse) HasInstanceType(instanceType string) bool {
for _, it := range r.AllInstanceTypes {
Expand Down
Loading
Loading