diff --git a/drivers/smb/driver.go b/drivers/smb/driver.go index 88325e5cf..50fe68f4d 100644 --- a/drivers/smb/driver.go +++ b/drivers/smb/driver.go @@ -3,9 +3,11 @@ package smb import ( "context" "errors" + "net" "path" "path/filepath" "strings" + "sync" "github.com/OpenListTeam/OpenList/v4/internal/driver" "github.com/OpenListTeam/OpenList/v4/internal/model" @@ -19,7 +21,11 @@ type SMB struct { lastConnTime int64 model.Storage Addition - fs *smb2.Share + connMu sync.Mutex + activeOps int + conn net.Conn + session *smb2.Session + fs *smb2.Share } func (d *SMB) Config() driver.Config { @@ -38,18 +44,17 @@ func (d *SMB) Init(ctx context.Context) error { } func (d *SMB) Drop(ctx context.Context) error { - if d.fs != nil { - _ = d.fs.Umount() - } - return nil + return d.closeFS() } func (d *SMB) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { - if err := d.checkConn(ctx); err != nil { + fs, release, err := d.acquireConn(ctx) + if err != nil { return nil, err } + defer release() fullPath := dir.GetPath() - rawFiles, err := d.fs.ReadDir(fullPath) + rawFiles, err := fs.ReadDir(fullPath) if err != nil { d.cleanLastConnTime() return nil, err @@ -72,11 +77,18 @@ func (d *SMB) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]m } func (d *SMB) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { - if err := d.checkConn(ctx); err != nil { + fs, release, err := d.acquireConn(ctx) + if err != nil { return nil, err } + needRelease := true + defer func() { + if needRelease { + release() + } + }() fullPath := file.GetPath() - remoteFile, err := d.fs.Open(fullPath) + remoteFile, err := fs.Open(fullPath) if err != nil { d.cleanLastConnTime() return nil, err @@ -87,19 +99,25 @@ func (d *SMB) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*m Limiter: stream.ServerDownloadLimit, Ctx: ctx, } + needRelease = false return &model.Link{ - RangeReader: stream.GetRangeReaderFromMFile(file.GetSize(), mFile), - SyncClosers: utils.NewSyncClosers(remoteFile), + RangeReader: stream.GetRangeReaderFromMFile(file.GetSize(), mFile), + SyncClosers: utils.NewSyncClosers(remoteFile, utils.CloseFunc(func() error { + release() + return nil + })), RequireReference: true, }, nil } func (d *SMB) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { - if err := d.checkConn(ctx); err != nil { + fs, release, err := d.acquireConn(ctx) + if err != nil { return err } + defer release() fullPath := filepath.Join(parentDir.GetPath(), dirName) - err := d.fs.MkdirAll(fullPath, 0700) + err = fs.MkdirAll(fullPath, 0700) if err != nil { d.cleanLastConnTime() return err @@ -109,12 +127,14 @@ func (d *SMB) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) } func (d *SMB) Move(ctx context.Context, srcObj, dstDir model.Obj) error { - if err := d.checkConn(ctx); err != nil { + fs, release, err := d.acquireConn(ctx) + if err != nil { return err } + defer release() srcPath := srcObj.GetPath() dstPath := filepath.Join(dstDir.GetPath(), srcObj.GetName()) - err := d.fs.Rename(srcPath, dstPath) + err = fs.Rename(srcPath, dstPath) if err != nil { d.cleanLastConnTime() return err @@ -124,12 +144,14 @@ func (d *SMB) Move(ctx context.Context, srcObj, dstDir model.Obj) error { } func (d *SMB) Rename(ctx context.Context, srcObj model.Obj, newName string) error { - if err := d.checkConn(ctx); err != nil { + fs, release, err := d.acquireConn(ctx) + if err != nil { return err } + defer release() srcPath := srcObj.GetPath() dstPath := filepath.Join(filepath.Dir(srcPath), newName) - err := d.fs.Rename(srcPath, dstPath) + err = fs.Rename(srcPath, dstPath) if err != nil { d.cleanLastConnTime() return err @@ -139,12 +161,13 @@ func (d *SMB) Rename(ctx context.Context, srcObj model.Obj, newName string) erro } func (d *SMB) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { - if err := d.checkConn(ctx); err != nil { + _, release, err := d.acquireConn(ctx) + if err != nil { return err } + defer release() srcPath := srcObj.GetPath() dstPath := filepath.Join(dstDir.GetPath(), srcObj.GetName()) - var err error if srcObj.IsDir() { err = d.CopyDir(srcPath, dstPath) } else { @@ -159,15 +182,16 @@ func (d *SMB) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { } func (d *SMB) Remove(ctx context.Context, obj model.Obj) error { - if err := d.checkConn(ctx); err != nil { + fs, release, err := d.acquireConn(ctx) + if err != nil { return err } - var err error + defer release() fullPath := obj.GetPath() if obj.IsDir() { - err = d.fs.RemoveAll(fullPath) + err = fs.RemoveAll(fullPath) } else { - err = d.fs.Remove(fullPath) + err = fs.Remove(fullPath) } if err != nil { d.cleanLastConnTime() @@ -178,11 +202,13 @@ func (d *SMB) Remove(ctx context.Context, obj model.Obj) error { } func (d *SMB) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - if err := d.checkConn(ctx); err != nil { + fs, release, err := d.acquireConn(ctx) + if err != nil { return err } + defer release() fullPath := filepath.Join(dstDir.GetPath(), stream.GetName()) - out, err := d.fs.Create(fullPath) + out, err := fs.Create(fullPath) if err != nil { d.cleanLastConnTime() return err @@ -191,7 +217,7 @@ func (d *SMB) Put(ctx context.Context, dstDir model.Obj, stream model.FileStream defer func() { _ = out.Close() if errors.Is(err, context.Canceled) { - _ = d.fs.Remove(fullPath) + _ = fs.Remove(fullPath) } }() err = utils.CopyWithCtx(ctx, out, driver.NewLimitedUploadStream(ctx, stream), stream.GetSize(), up) @@ -202,13 +228,16 @@ func (d *SMB) Put(ctx context.Context, dstDir model.Obj, stream model.FileStream } func (d *SMB) GetDetails(ctx context.Context) (*model.StorageDetails, error) { - if err := d.checkConn(ctx); err != nil { + fs, release, err := d.acquireConn(ctx) + if err != nil { return nil, err } - stat, err := d.fs.Statfs(d.RootFolderPath) + defer release() + stat, err := fs.Statfs(d.RootFolderPath) if err != nil { return nil, err } + d.updateLastConnTime() total := int64(stat.BlockSize() * stat.TotalBlockCount()) free := int64(stat.BlockSize() * stat.AvailableBlockCount()) return &model.StorageDetails{ diff --git a/drivers/smb/util.go b/drivers/smb/util.go index 6ae365f8f..6834feab5 100644 --- a/drivers/smb/util.go +++ b/drivers/smb/util.go @@ -2,8 +2,10 @@ package smb import ( "context" + "errors" "fmt" "io/fs" + "net" "os" "path/filepath" "sync/atomic" @@ -29,37 +31,103 @@ func (d *SMB) getLastConnTime() time.Time { func (d *SMB) initFS(ctx context.Context) error { _, err, _ := singleflight.AnyGroup.Do(fmt.Sprintf("SMB.initFS:%p", d), func() (any, error) { - return nil, d._initFS(ctx) + d.connMu.Lock() + defer d.connMu.Unlock() + return nil, d.initFSLocked(ctx) }) return err } + func (d *SMB) _initFS(ctx context.Context) error { + d.connMu.Lock() + defer d.connMu.Unlock() + return d.initFSLocked(ctx) +} + +func (d *SMB) initFSLocked(ctx context.Context) error { + _ = d.closeFSLocked() dialer := &smb2.Dialer{ Initiator: &smb2.NTLMInitiator{ User: d.Username, Password: d.Password, }, } - s, err := dialer.Dial(ctx, d.Address) + conn, err := net.Dial("tcp", d.Address) if err != nil { return err } - d.fs, err = s.Mount(d.ShareName) + s, err := dialer.DialConn(ctx, conn, d.Address) if err != nil { + _ = conn.Close() return err } + fs, err := s.Mount(d.ShareName) + if err != nil { + _ = s.Logoff() + _ = conn.Close() + return err + } + d.conn = conn + d.session = s + d.fs = fs d.updateLastConnTime() + return nil +} + +func (d *SMB) closeFS() error { + d.connMu.Lock() + defer d.connMu.Unlock() + return d.closeFSLocked() +} + +func (d *SMB) closeFSLocked() error { + var err error + if d.fs != nil { + err = errors.Join(err, d.fs.Umount()) + d.fs = nil + } + if d.session != nil { + err = errors.Join(err, d.session.Logoff()) + d.session = nil + } + if d.conn != nil { + err = errors.Join(err, d.conn.Close()) + d.conn = nil + } + d.cleanLastConnTime() return err } func (d *SMB) checkConn(ctx context.Context) error { - if time.Since(d.getLastConnTime()) < 5*time.Minute { - return nil + _, release, err := d.acquireConn(ctx) + if release != nil { + release() } - if d.fs != nil { - _ = d.fs.Umount() + return err +} + +func (d *SMB) acquireConn(ctx context.Context) (*smb2.Share, func(), error) { + d.connMu.Lock() + defer d.connMu.Unlock() + + if d.fs == nil || (time.Since(d.getLastConnTime()) >= 5*time.Minute && d.activeOps == 0) { + if err := d.initFSLocked(ctx); err != nil { + return nil, nil, err + } + } + if d.fs == nil { + return nil, nil, errors.New("smb share is not initialized") + } + d.activeOps++ + return d.fs, d.releaseConn, nil +} + +func (d *SMB) releaseConn() { + d.connMu.Lock() + defer d.connMu.Unlock() + if d.activeOps > 0 { + d.activeOps-- } - return d.initFS(ctx) } // CopyFile File copies a single file from src to dst