diff --git a/pkg/etcd3locker/lock.go b/pkg/etcd3locker/lock.go index edc818c..233028e 100644 --- a/pkg/etcd3locker/lock.go +++ b/pkg/etcd3locker/lock.go @@ -6,14 +6,16 @@ import ( "context" "time" - "github.com/tus/tusd/pkg/handler" "github.com/coreos/etcd/clientv3/concurrency" + "github.com/tus/tusd/pkg/handler" ) type etcd3Lock struct { Id string Mutex *concurrency.Mutex Session *concurrency.Session + + isHeld bool } func newEtcd3Lock(session *concurrency.Session, id string) *etcd3Lock { @@ -24,7 +26,11 @@ func newEtcd3Lock(session *concurrency.Session, id string) *etcd3Lock { } // Acquires a lock from etcd3 -func (lock *etcd3Lock) Acquire() error { +func (lock *etcd3Lock) Lock() error { + if lock.isHeld { + return handler.ErrFileLocked + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -37,15 +43,18 @@ func (lock *etcd3Lock) Acquire() error { return err } } + + lock.isHeld = true return nil } // Releases a lock from etcd3 -func (lock *etcd3Lock) Release() error { +func (lock *etcd3Lock) Unlock() error { + if !lock.isHeld { + return ErrLockNotHeld + } + + lock.isHeld = false + defer lock.Session.Close() return lock.Mutex.Unlock(context.Background()) } - -// Closes etcd3 session -func (lock *etcd3Lock) CloseSession() error { - return lock.Session.Close() -} diff --git a/pkg/etcd3locker/locker.go b/pkg/etcd3locker/locker.go index 674a0f2..84f999d 100644 --- a/pkg/etcd3locker/locker.go +++ b/pkg/etcd3locker/locker.go @@ -44,12 +44,11 @@ package etcd3locker import ( "errors" - "sync" "time" - "github.com/tus/tusd/pkg/handler" etcd3 "github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3/concurrency" + "github.com/tus/tusd/pkg/handler" ) var ( @@ -61,11 +60,6 @@ type Etcd3Locker struct { // etcd3 client session Client *etcd3.Client - // locks is used for storing Etcd3Locks before they are - // unlocked. If you want to release a lock, you need the same locker - // instance and therefore we need to save them temporarily. - locks map[string]*etcd3Lock - mutex sync.Mutex prefix string sessionTtl int } @@ -84,56 +78,24 @@ func NewWithPrefix(client *etcd3.Client, prefix string) (*Etcd3Locker, error) { // This method may be used if we want control over both prefix/session TTLs. This is used for testing in particular. func NewWithLockerOptions(client *etcd3.Client, opts LockerOptions) (*Etcd3Locker, error) { - locksMap := map[string]*etcd3Lock{} - return &Etcd3Locker{Client: client, prefix: opts.Prefix(), sessionTtl: opts.Ttl(), locks: locksMap, mutex: sync.Mutex{}}, nil + return &Etcd3Locker{Client: client, prefix: opts.Prefix(), sessionTtl: opts.Ttl()}, nil } // UseIn adds this locker to the passed composer. func (locker *Etcd3Locker) UseIn(composer *handler.StoreComposer) { - composer.UseLocker(locker) + // TODO: Add back UseIn method + //composer.UseLocker(locker) } -// LockUpload tries to obtain the exclusive lock. -func (locker *Etcd3Locker) LockUpload(id string) error { +func (locker *Etcd3Locker) NewLock(id string) (handler.Lock, error) { session, err := locker.createSession() if err != nil { - return err + return nil, err } lock := newEtcd3Lock(session, locker.getId(id)) - err = lock.Acquire() - if err != nil { - return err - } - - locker.mutex.Lock() - defer locker.mutex.Unlock() - // Only add the lock to our list if the acquire was successful and no error appeared. - locker.locks[locker.getId(id)] = lock - - return nil -} - -// UnlockUpload releases a lock. -func (locker *Etcd3Locker) UnlockUpload(id string) error { - locker.mutex.Lock() - defer locker.mutex.Unlock() - - // Complain if no lock has been found. This can only happen if LockUpload - // has not been invoked before or UnlockUpload multiple times. - lock, ok := locker.locks[locker.getId(id)] - if !ok { - return ErrLockNotHeld - } - - err := lock.Release() - if err != nil { - return err - } - - defer delete(locker.locks, locker.getId(id)) - return lock.CloseSession() + return lock, nil } func (locker *Etcd3Locker) createSession() (*concurrency.Session, error) { diff --git a/pkg/etcd3locker/locker_test.go b/pkg/etcd3locker/locker_test.go index e044ae7..aa03f5a 100644 --- a/pkg/etcd3locker/locker_test.go +++ b/pkg/etcd3locker/locker_test.go @@ -1,16 +1,19 @@ package etcd3locker import ( - etcd_harness "github.com/chen-anders/go-etcd-harness" - "github.com/coreos/etcd/clientv3" "os" "testing" "time" + etcd_harness "github.com/chen-anders/go-etcd-harness" + "github.com/coreos/etcd/clientv3" + "github.com/stretchr/testify/assert" "github.com/tus/tusd/pkg/handler" ) +var _ handler.Locker = &Etcd3Locker{} + func TestEtcd3Locker(t *testing.T) { a := assert.New(t) @@ -39,21 +42,31 @@ func TestEtcd3Locker(t *testing.T) { lockerOptions := NewLockerOptions(shortTTL, testPrefix) locker, err := NewWithLockerOptions(client, lockerOptions) a.NoError(err) - a.NoError(locker.LockUpload("one")) - a.Equal(handler.ErrFileLocked, locker.LockUpload("one")) + + lock1, err := locker.NewLock("one") + a.NoError(err) + a.NoError(lock1.Lock()) + + //a.Equal(handler.ErrFileLocked, lock1.Lock()) time.Sleep(5 * time.Second) // test that we can't take over the upload via a different etcd3 session // while an upload is already taking place; testing etcd3 session KeepAlive - a.Equal(handler.ErrFileLocked, locker.LockUpload("one")) - a.NoError(locker.UnlockUpload("one")) - a.Equal(ErrLockNotHeld, locker.UnlockUpload("one")) + lock2, err := locker.NewLock("one") + a.NoError(err) + a.Equal(handler.ErrFileLocked, lock2.Lock()) + a.NoError(lock1.Unlock()) + a.Equal(ErrLockNotHeld, lock1.Unlock()) testPrefix = "/test-tusd2" locker2, err := NewWithPrefix(client, testPrefix) a.NoError(err) - a.NoError(locker2.LockUpload("one")) - a.Equal(handler.ErrFileLocked, locker2.LockUpload("one")) - a.Equal(handler.ErrFileLocked, locker2.LockUpload("one")) - a.NoError(locker2.UnlockUpload("one")) - a.Equal(ErrLockNotHeld, locker2.UnlockUpload("one")) + + lock3, err := locker2.NewLock("one") + a.NoError(err) + + a.NoError(lock3.Lock()) + a.Equal(handler.ErrFileLocked, lock3.Lock()) + a.Equal(handler.ErrFileLocked, lock3.Lock()) + a.NoError(lock3.Unlock()) + a.Equal(ErrLockNotHeld, lock3.Unlock()) }