Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 57 additions & 28 deletions drivers/smb/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package smb
import (
"context"
"errors"
"net"
"path"
"path/filepath"
"strings"
"sync"

"github.com/OpenListTeam/OpenList/v4/internal/driver"
"github.com/OpenListTeam/OpenList/v4/internal/model"
Expand All @@ -19,7 +21,11 @@ type SMB struct {
lastConnTime int64
model.Storage
Addition
fs *smb2.Share
connMu sync.Mutex
activeOps int
conn net.Conn
session *smb2.Session
fs *smb2.Share
}

func (d *SMB) Config() driver.Config {
Expand All @@ -38,18 +44,17 @@ func (d *SMB) Init(ctx context.Context) error {
}

func (d *SMB) Drop(ctx context.Context) error {
if d.fs != nil {
_ = d.fs.Umount()
}
return nil
return d.closeFS()
}

func (d *SMB) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
if err := d.checkConn(ctx); err != nil {
fs, release, err := d.acquireConn(ctx)
if err != nil {
return nil, err
}
defer release()
fullPath := dir.GetPath()
rawFiles, err := d.fs.ReadDir(fullPath)
rawFiles, err := fs.ReadDir(fullPath)
if err != nil {
d.cleanLastConnTime()
return nil, err
Expand All @@ -72,11 +77,18 @@ func (d *SMB) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]m
}

func (d *SMB) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
if err := d.checkConn(ctx); err != nil {
fs, release, err := d.acquireConn(ctx)
if err != nil {
return nil, err
}
needRelease := true
defer func() {
if needRelease {
release()
}
}()
fullPath := file.GetPath()
remoteFile, err := d.fs.Open(fullPath)
remoteFile, err := fs.Open(fullPath)
if err != nil {
d.cleanLastConnTime()
return nil, err
Expand All @@ -87,19 +99,25 @@ func (d *SMB) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*m
Limiter: stream.ServerDownloadLimit,
Ctx: ctx,
}
needRelease = false
return &model.Link{
RangeReader: stream.GetRangeReaderFromMFile(file.GetSize(), mFile),
SyncClosers: utils.NewSyncClosers(remoteFile),
RangeReader: stream.GetRangeReaderFromMFile(file.GetSize(), mFile),
SyncClosers: utils.NewSyncClosers(remoteFile, utils.CloseFunc(func() error {
release()
return nil
})),
RequireReference: true,
}, nil
}

func (d *SMB) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error {
if err := d.checkConn(ctx); err != nil {
fs, release, err := d.acquireConn(ctx)
if err != nil {
return err
}
defer release()
fullPath := filepath.Join(parentDir.GetPath(), dirName)
err := d.fs.MkdirAll(fullPath, 0700)
err = fs.MkdirAll(fullPath, 0700)
if err != nil {
d.cleanLastConnTime()
return err
Expand All @@ -109,12 +127,14 @@ func (d *SMB) MakeDir(ctx context.Context, parentDir model.Obj, dirName string)
}

func (d *SMB) Move(ctx context.Context, srcObj, dstDir model.Obj) error {
if err := d.checkConn(ctx); err != nil {
fs, release, err := d.acquireConn(ctx)
if err != nil {
return err
}
defer release()
srcPath := srcObj.GetPath()
dstPath := filepath.Join(dstDir.GetPath(), srcObj.GetName())
err := d.fs.Rename(srcPath, dstPath)
err = fs.Rename(srcPath, dstPath)
if err != nil {
d.cleanLastConnTime()
return err
Expand All @@ -124,12 +144,14 @@ func (d *SMB) Move(ctx context.Context, srcObj, dstDir model.Obj) error {
}

func (d *SMB) Rename(ctx context.Context, srcObj model.Obj, newName string) error {
if err := d.checkConn(ctx); err != nil {
fs, release, err := d.acquireConn(ctx)
if err != nil {
return err
}
defer release()
srcPath := srcObj.GetPath()
dstPath := filepath.Join(filepath.Dir(srcPath), newName)
err := d.fs.Rename(srcPath, dstPath)
err = fs.Rename(srcPath, dstPath)
if err != nil {
d.cleanLastConnTime()
return err
Expand All @@ -139,12 +161,13 @@ func (d *SMB) Rename(ctx context.Context, srcObj model.Obj, newName string) erro
}

func (d *SMB) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
if err := d.checkConn(ctx); err != nil {
_, release, err := d.acquireConn(ctx)
if err != nil {
return err
}
defer release()
srcPath := srcObj.GetPath()
dstPath := filepath.Join(dstDir.GetPath(), srcObj.GetName())
var err error
if srcObj.IsDir() {
err = d.CopyDir(srcPath, dstPath)
} else {
Expand All @@ -159,15 +182,16 @@ func (d *SMB) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
}

func (d *SMB) Remove(ctx context.Context, obj model.Obj) error {
if err := d.checkConn(ctx); err != nil {
fs, release, err := d.acquireConn(ctx)
if err != nil {
return err
}
var err error
defer release()
fullPath := obj.GetPath()
if obj.IsDir() {
err = d.fs.RemoveAll(fullPath)
err = fs.RemoveAll(fullPath)
} else {
err = d.fs.Remove(fullPath)
err = fs.Remove(fullPath)
}
if err != nil {
d.cleanLastConnTime()
Expand All @@ -178,11 +202,13 @@ func (d *SMB) Remove(ctx context.Context, obj model.Obj) error {
}

func (d *SMB) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
if err := d.checkConn(ctx); err != nil {
fs, release, err := d.acquireConn(ctx)
if err != nil {
return err
}
defer release()
fullPath := filepath.Join(dstDir.GetPath(), stream.GetName())
out, err := d.fs.Create(fullPath)
out, err := fs.Create(fullPath)
if err != nil {
d.cleanLastConnTime()
return err
Expand All @@ -191,7 +217,7 @@ func (d *SMB) Put(ctx context.Context, dstDir model.Obj, stream model.FileStream
defer func() {
_ = out.Close()
if errors.Is(err, context.Canceled) {
_ = d.fs.Remove(fullPath)
_ = fs.Remove(fullPath)
}
}()
err = utils.CopyWithCtx(ctx, out, driver.NewLimitedUploadStream(ctx, stream), stream.GetSize(), up)
Expand All @@ -202,13 +228,16 @@ func (d *SMB) Put(ctx context.Context, dstDir model.Obj, stream model.FileStream
}

func (d *SMB) GetDetails(ctx context.Context) (*model.StorageDetails, error) {
if err := d.checkConn(ctx); err != nil {
fs, release, err := d.acquireConn(ctx)
if err != nil {
return nil, err
}
stat, err := d.fs.Statfs(d.RootFolderPath)
defer release()
stat, err := fs.Statfs(d.RootFolderPath)
if err != nil {
return nil, err
}
d.updateLastConnTime()
total := int64(stat.BlockSize() * stat.TotalBlockCount())
free := int64(stat.BlockSize() * stat.AvailableBlockCount())
return &model.StorageDetails{
Expand Down
84 changes: 76 additions & 8 deletions drivers/smb/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package smb

import (
"context"
"errors"
"fmt"
"io/fs"
"net"
"os"
"path/filepath"
"sync/atomic"
Expand All @@ -29,37 +31,103 @@ func (d *SMB) getLastConnTime() time.Time {

func (d *SMB) initFS(ctx context.Context) error {
_, err, _ := singleflight.AnyGroup.Do(fmt.Sprintf("SMB.initFS:%p", d), func() (any, error) {
return nil, d._initFS(ctx)
d.connMu.Lock()
defer d.connMu.Unlock()
return nil, d.initFSLocked(ctx)
})
return err
}

func (d *SMB) _initFS(ctx context.Context) error {
d.connMu.Lock()
defer d.connMu.Unlock()
return d.initFSLocked(ctx)
}

func (d *SMB) initFSLocked(ctx context.Context) error {
_ = d.closeFSLocked()
dialer := &smb2.Dialer{
Initiator: &smb2.NTLMInitiator{
User: d.Username,
Password: d.Password,
},
}
s, err := dialer.Dial(ctx, d.Address)
conn, err := net.Dial("tcp", d.Address)
if err != nil {
return err
}
d.fs, err = s.Mount(d.ShareName)
s, err := dialer.DialConn(ctx, conn, d.Address)
if err != nil {
_ = conn.Close()
return err
}
fs, err := s.Mount(d.ShareName)
if err != nil {
_ = s.Logoff()
_ = conn.Close()
return err
}
d.conn = conn
d.session = s
d.fs = fs
d.updateLastConnTime()
return nil
}

func (d *SMB) closeFS() error {
d.connMu.Lock()
defer d.connMu.Unlock()
return d.closeFSLocked()
}

func (d *SMB) closeFSLocked() error {
var err error
if d.fs != nil {
err = errors.Join(err, d.fs.Umount())
d.fs = nil
}
if d.session != nil {
err = errors.Join(err, d.session.Logoff())
d.session = nil
}
if d.conn != nil {
err = errors.Join(err, d.conn.Close())
d.conn = nil
}
d.cleanLastConnTime()
return err
}

func (d *SMB) checkConn(ctx context.Context) error {
if time.Since(d.getLastConnTime()) < 5*time.Minute {
return nil
_, release, err := d.acquireConn(ctx)
if release != nil {
release()
}
if d.fs != nil {
_ = d.fs.Umount()
return err
}

func (d *SMB) acquireConn(ctx context.Context) (*smb2.Share, func(), error) {
d.connMu.Lock()
defer d.connMu.Unlock()

if d.fs == nil || (time.Since(d.getLastConnTime()) >= 5*time.Minute && d.activeOps == 0) {
if err := d.initFSLocked(ctx); err != nil {
return nil, nil, err
}
}
if d.fs == nil {
return nil, nil, errors.New("smb share is not initialized")
}
d.activeOps++
return d.fs, d.releaseConn, nil
}

func (d *SMB) releaseConn() {
d.connMu.Lock()
defer d.connMu.Unlock()
if d.activeOps > 0 {
d.activeOps--
}
return d.initFS(ctx)
}

// CopyFile File copies a single file from src to dst
Expand Down