timerwheel.go 3.62 KB
package components

import (
	"container/list"
	"sync"
	"sync/atomic"
	"time"
)

// Wheel geometry for the skynet-style timing wheel: one "near" wheel of
// TimeNear slots indexed by the low TimeNearShift bits of a tick count, and
// cascade levels of TimeLevel slots, each indexed by a further TimeLevelShift
// bits.
const (
	TimeNearShift = 8
	TimeNear      = 1 << TimeNearShift
	TimeNearMask  = TimeNear - 1

	TimeLevelShift = 6
	TimeLevel      = 1 << TimeLevelShift
	TimeLevelMask  = TimeLevel - 1
)

// Goroutine pool sizing; both values are handed to NewWorkPool by
// NewTimeWheel.
const (
	// WorkerPoolSize is the number of workers in the goroutine pool.
	WorkerPoolSize = 10
	// MaxTaskPerWorker is presumably the per-worker task queue capacity;
	// confirm against NewWorkPool.
	MaxTaskPerWorker = 20
)

// bucket is one slot of the timing wheel: a mutex-guarded list of timers.
type bucket struct {
	// expiration is initialised to -1 by newBucket; nothing in this file
	// reads it afterwards.
	expiration int32
	// timers holds *timer values in insertion order.
	timers *list.List

	// mu guards timers.
	mu sync.Mutex
}

// newBucket returns an empty bucket ready for use.
func newBucket() *bucket {
	return &bucket{
		expiration: -1,
		timers:     list.New(),
	}
}

// Add appends t to the bucket.
func (b *bucket) Add(t *timer) {
	b.mu.Lock()
	defer b.mu.Unlock()

	b.timers.PushBack(t)
}

// Flush removes every timer currently in the bucket and passes each one, in
// insertion order, to reinsert.
//
// The pending list is detached while holding the lock and reinsert is then
// invoked with the lock released. This lets reinsert call Add on this same
// bucket, or block for a while, without deadlocking or stalling concurrent
// Add calls. Timers added while a flush is in progress stay queued for the
// next Flush.
func (b *bucket) Flush(reinsert func(t *timer)) {
	b.mu.Lock()
	if b.timers.Len() == 0 {
		// Nothing to do; avoid allocating a replacement list on idle ticks.
		b.mu.Unlock()
		return
	}
	pending := b.timers
	b.timers = list.New()
	b.mu.Unlock()

	for e := pending.Front(); e != nil; e = e.Next() {
		reinsert(e.Value.(*timer))
	}
}

// timer is a single scheduled callback.
type timer struct {
	// expiration is the absolute wheel tick at which the timer is due.
	expiration uint32
	// f is the callback to run once the timer is due.
	f func()
}

// TimingWheel is the package-level wheel used by TimeOut and StopTimer.
var TimingWheel *TimeWheel

// init builds TimingWheel and starts it, so the ticker goroutine is running
// and the worker pool has been started as soon as the package is initialised.
func init()  {
	TimingWheel = NewTimeWheel()
	TimingWheel.Start()
}

// TimeWheel is a hierarchical timing wheel: a near wheel of TimeNear buckets
// plus four cascade levels of TimeLevel buckets each. Due callbacks are
// handed to WorkPool for execution.
type TimeWheel struct {
	// tick is the duration of one wheel tick; ticker fires once per tick.
	tick time.Duration
	// ticker is created by Start and drives update.
	ticker *time.Ticker
	// near is indexed by the low TimeNearShift bits of the expiration tick.
	near 			[TimeNear]*bucket
	// t holds the cascade levels; level i is indexed by the TimeLevelShift
	// bits starting at bit TimeNearShift + i*TimeLevelShift.
	t				[4][TimeLevel]*bucket
	// time is the current tick count; it is advanced with sync/atomic in
	// shift.
	time 			uint32

	// WorkPool receives the callbacks of due timers.
	WorkPool *WorkPool
	// exit is closed by Stop to end the ticker goroutine.
	exit chan struct{}
	// exitFlag is non-zero once Stop has run.
	exitFlag uint32
}

// NewTimeWheel allocates a wheel with a 10ms tick, a worker pool built from
// WorkerPoolSize and MaxTaskPerWorker, and an empty bucket in every slot of
// the near wheel and of each cascade level. The wheel does not tick until
// Start is called.
func NewTimeWheel() *TimeWheel {
	tw := new(TimeWheel)
	tw.tick = 10 * time.Millisecond
	tw.WorkPool = NewWorkPool(WorkerPoolSize, MaxTaskPerWorker)
	tw.exit = make(chan struct{})

	for slot := range tw.near {
		tw.near[slot] = newBucket()
	}
	for level := range tw.t {
		for slot := range tw.t[level] {
			tw.t[level][slot] = newBucket()
		}
	}
	return tw
}

// add files t into the wheel slot matching its expiration tick.
//
// It returns false, without storing the timer, when the expiration is not
// later than the current tick; the caller is then expected to run the timer
// itself. Otherwise the timer is stored and true is returned.
func (tw *TimeWheel) add(t *timer) bool {
	expire := t.expiration
	now := atomic.LoadUint32(&tw.time)
	if expire <= now {
		return false
	}

	// Same value above the low TimeNearShift bits: due within the current
	// revolution of the near wheel.
	if (expire | TimeNearMask) == (now | TimeNearMask) {
		tw.near[expire&TimeNearMask].Add(t)
		return true
	}

	// Pick the lowest cascade level whose span still contains the timer.
	// Level 0 spans TimeNear<<TimeLevelShift ticks and every further level is
	// TimeLevel times wider, so the mask must grow by TimeLevelShift bits per
	// level to stay in step with the slot index computed below. Anything that
	// fits none of the first three levels lands in level 3.
	level := 0
	span := uint32(TimeNear) << TimeLevelShift
	for ; level < 3; level++ {
		if (expire | (span - 1)) == (now | (span - 1)) {
			break
		}
		span <<= TimeLevelShift
	}

	slot := (expire >> (TimeNearShift + level*TimeLevelShift)) & TimeLevelMask
	tw.t[level][slot].Add(t)
	return true
}

// addOrRun stores t in the wheel, or, when add reports that it is already
// due, hands its callback to the worker pool. The worker is selected by the
// timer's expiration tick modulo the pool size.
func (tw *TimeWheel) addOrRun(t *timer) {
	if tw.add(t) {
		return
	}

	pool := tw.WorkPool
	// Send the callback to the chosen worker's task queue.
	queue := pool.TaskQueue[int64(t.expiration)%pool.WorkerPoolSize]
	queue <- t.f
}

// moveList empties slot idx of cascade level level, passing every timer it
// held back through addOrRun so it is either re-filed or dispatched.
func (tw *TimeWheel) moveList(level, idx int) {
	tw.t[level][idx].Flush(tw.addOrRun)
}

// shift advances the wheel by one tick and cascades at most one
// higher-level slot.
//
// When the tick counter wraps to zero, slot 0 of the top level is
// redistributed. Otherwise, each time the bits below a level are all zero,
// that level's slot index is inspected: the first non-zero index found is
// redistributed and the search stops.
func (tw *TimeWheel) shift() {
	now := atomic.AddUint32(&tw.time, 1)
	if now == 0 {
		tw.moveList(3, 0)
		return
	}

	span := TimeNear
	rest := now >> TimeNearShift
	for level := 0; now&uint32(span-1) == 0; level++ {
		if slot := rest & TimeLevelMask; slot != 0 {
			tw.moveList(level, int(slot))
			return
		}
		span <<= TimeLevelShift
		rest >>= TimeLevelShift
	}
}

// execute flushes the near-wheel slot for the current tick, passing each
// timer through addOrRun.
//
// The tick counter is read with an atomic load so that every access to
// tw.time in this file goes through sync/atomic.
func (tw *TimeWheel) execute() {
	idx := atomic.LoadUint32(&tw.time) & TimeNearMask
	tw.near[idx].Flush(tw.addOrRun)
}

// update processes one tick: it flushes the near slot for the current tick,
// advances the wheel (cascading a higher-level slot if one is due), and then
// flushes the near slot for the new current tick.
func (tw *TimeWheel) update()  {
	// Flush whatever is filed under the current tick before moving on.
	tw.execute()
	tw.shift()
	// Flush the slot that just became current, including timers cascaded
	// into it by shift.
	tw.execute()
}

// Start creates the ticker, starts the worker pool and launches the
// goroutine that calls update once per tick. The goroutine returns when the
// exit channel is closed (see Stop) and stops the ticker on its way out so
// the ticker's resources are released.
func (tw *TimeWheel) Start() {
	ticker := time.NewTicker(tw.tick)
	tw.ticker = ticker
	tw.WorkPool.StartWorkerPool()

	go func() {
		// Without this the ticker would keep firing after the loop exits.
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				tw.update()
			case <-tw.exit:
				return
			}
		}
	}()
}

// Stop signals the ticker goroutine to exit by closing the exit channel.
//
// It is safe to call more than once and from multiple goroutines: the
// compare-and-swap lets exactly one caller through to close the channel, so
// the channel is never closed twice.
func (tw *TimeWheel) Stop() {
	if !atomic.CompareAndSwapUint32(&tw.exitFlag, 0, 1) {
		return
	}
	close(tw.exit)
}

// afterFunc schedules f to run once expiration has elapsed, measured in whole
// wheel ticks from the current tick.
//
// A zero or negative expiration (or one shorter than a tick) makes the timer
// due immediately, so addOrRun dispatches it straight to the worker pool. An
// expiration that would overflow the 32-bit tick counter is saturated at the
// largest representable tick instead of wrapping around.
func (tw *TimeWheel) afterFunc(expiration time.Duration, f func()) {
	const maxTick = ^uint32(0)

	now := atomic.LoadUint32(&tw.time)
	ticks := int64(expiration / tw.tick)
	room := int64(maxTick - now)
	switch {
	case ticks < 0:
		// Converting a negative count to uint32 would yield a huge delay.
		ticks = 0
	case ticks > room:
		ticks = room
	}

	tw.addOrRun(&timer{
		expiration: now + uint32(ticks),
		f:          f,
	})
}

// TimeOut schedules f on the package-level TimingWheel to run after expire
// has elapsed.
func TimeOut(expire time.Duration, f func()) {
	TimingWheel.afterFunc(expire, f)
}

// StopTimer stops the package-level TimingWheel.
func StopTimer()  {
	TimingWheel.Stop()
}