timewheel.go 4.23 KB
package timewheel

import (
	"pro2d/src/common"
	"pro2d/src/components/workpool"
	"sync/atomic"
	"time"
	"unsafe"
)

// TimeWheel is a single level of a hierarchical timing wheel. Each level owns
// wheelSize buckets that each cover tickMs milliseconds; timers that do not fit
// into this level's span are pushed to overflowWheel, whose tick equals this
// level's wheelTime.
type TimeWheel struct {
	ticker    *time.Ticker
	tickMs    int64 // duration of one tick, in milliseconds
	wheelSize int64
	startMs   int64 // start time of the current rotation, in milliseconds
	endMs     int64
	wheelTime int64 // time needed to complete one full rotation (wheelSize * tickMs)
	level 	  int64 // level of this wheel in the hierarchy (0 for the wheel built by NewTimeWheel)

	// list of time slots
	bucket []*bucket
	currentTime int64 // current time in milliseconds, truncated to a multiple of tickMs; accessed atomically
	prevflowWheel unsafe.Pointer // type: *TimeWheel (the finer-grained wheel one level below; nil at level 0)
	overflowWheel unsafe.Pointer // type: *TimeWheel (the coarser wheel one level above; created lazily by add)
	exitC     chan struct{}

	WorkPool *workpool.WorkPool
}

// NewTimeWheel builds the level-0 timing wheel.
//
// tick is the resolution of the wheel and is converted to whole milliseconds;
// wheelSize is the number of buckets per level. The wheel shares a single work
// pool, constructed from common.WorkerPoolSize and common.MaxTaskPerWorker,
// with every overflow level created later.
//
// It panics if tick is shorter than one millisecond or if wheelSize is not
// positive: a wheel without buckets cannot index any slot and would otherwise
// fail later with an obscure divide-by-zero or makeslice panic.
func NewTimeWheel(tick time.Duration, wheelSize int64) *TimeWheel {
	// Convert the tick to milliseconds; sub-millisecond ticks collapse to 0.
	tickMs := int64(tick / time.Millisecond)
	if tickMs <= 0 {
		panic("tick must be greater than or equal to 1 ms")
	}
	if wheelSize <= 0 {
		panic("wheelSize must be greater than or equal to 1")
	}

	startMs := time.Now().UnixMilli()

	pool := workpool.NewWorkPool(common.WorkerPoolSize, common.MaxTaskPerWorker)
	return newTimingWheel(tickMs, wheelSize, startMs, 0, nil, pool)
}

// newTimingWheel allocates one wheel level.
//
// tick and start are in milliseconds, level is the depth of this wheel in the
// hierarchy, prev is the wheel one level below (nil for level 0) and pool is
// the work pool shared by all levels.
func newTimingWheel(tick, wheelSize int64, start, level int64, prev *TimeWheel, pool *workpool.WorkPool) *TimeWheel {
	// One full rotation covers wheelSize ticks.
	span := wheelSize * tick

	slots := make([]*bucket, wheelSize)
	for idx := 0; idx < len(slots); idx++ {
		slots[idx] = newBucket()
	}

	tw := new(TimeWheel)
	tw.level = level
	tw.prevflowWheel = unsafe.Pointer(prev)
	tw.WorkPool = pool

	tw.tickMs = tick
	tw.wheelSize = wheelSize
	tw.wheelTime = span
	tw.bucket = slots

	tw.startMs = start
	tw.endMs = span + start
	// The wheel clock always sits on a tick boundary.
	tw.currentTime = truncate(start, tick)
	tw.exitC = make(chan struct{})
	return tw
}

// truncate rounds dst toward zero to a multiple of m.
// m must be non-zero; a zero m causes a division-by-zero panic.
func truncate(dst, m int64) int64 {
	remainder := dst % m
	return dst - remainder
}

// add tries to place t into this wheel or one of its overflow wheels.
//
// It returns false when t expires before the end of the current tick, in which
// case the caller is responsible for running the timer; it returns true when
// the timer has been stored in a bucket.
func (tw *TimeWheel) add(t *Timer) bool {
	currentTime := atomic.LoadInt64(&tw.currentTime)

	switch {
	case t.expiration < currentTime+tw.tickMs:
		// Too close to (or already past) the current tick to be bucketed.
		return false

	case t.expiration < currentTime+tw.wheelTime:
		// Fits into this level: slot = number of ticks % wheelSize.
		virtualID := t.expiration / tw.tickMs
		b := tw.bucket[virtualID%tw.wheelSize]
		b.Add(t)

		b.SetExpiration(virtualID * tw.tickMs)
		return true
	}

	// Beyond this level's span: hand over to the overflow wheel, creating it
	// on first use. The compare-and-swap lets concurrent callers race safely;
	// whichever wheel wins is re-loaded and used by everyone.
	overflowWheel := atomic.LoadPointer(&tw.overflowWheel)
	if overflowWheel == nil {
		level := atomic.LoadInt64(&tw.level) + 1
		atomic.CompareAndSwapPointer(
			&tw.overflowWheel,
			nil,
			unsafe.Pointer(newTimingWheel(tw.wheelTime, tw.wheelSize, currentTime, level, tw, tw.WorkPool)),
		)
		overflowWheel = atomic.LoadPointer(&tw.overflowWheel)
	}

	// Propagate the overflow wheel's verdict. Reporting success unconditionally
	// would silently drop a timer that the overflow wheel refused (possible
	// when its clock advanced after currentTime was read above).
	return (*TimeWheel)(overflowWheel).add(t)
}

// addOrRun schedules t on the wheel; when add reports that the timer can no
// longer be bucketed, its task is dispatched to the work pool instead. The
// target queue is chosen by the timer's expiration modulo the pool size.
func (tw *TimeWheel) addOrRun(t *Timer) {
	if tw.add(t) {
		return
	}

	// Hand the task to the selected worker queue.
	pool := tw.WorkPool
	queue := pool.TaskQueue[t.expiration%pool.WorkerPoolSize]
	queue <- t.task
}

// advanceClock moves this wheel's clock to expiration (milliseconds), flushes
// the bucket that corresponds to that instant and, once a full rotation has
// elapsed, advances the overflow wheel as well.
//
// Level 0 flushes its bucket through its own addOrRun; higher levels flush
// through the addOrRun of the wheel one level below, and skip the flush if
// that wheel is not set.
func (tw *TimeWheel) advanceClock(expiration int64) {
	now := truncate(expiration, tw.tickMs)
	atomic.StoreInt64(&tw.currentTime, now)

	// Pick the wheel that receives the flushed timers.
	target := tw
	if atomic.LoadInt64(&tw.level) != 0 {
		target = (*TimeWheel)(atomic.LoadPointer(&tw.prevflowWheel))
	}
	if target != nil {
		// slot = number of elapsed ticks % wheelSize
		slot := (expiration / tw.tickMs) % tw.wheelSize
		tw.bucket[slot].Flush(target.addOrRun)
	}

	// Nothing more to do until this wheel has completed its rotation.
	if now < tw.endMs {
		return
	}

	// Rotation complete: open the next window and tick the coarser wheel.
	atomic.StoreInt64(&tw.startMs, now)
	atomic.StoreInt64(&tw.endMs, now+tw.wheelTime)

	if next := (*TimeWheel)(atomic.LoadPointer(&tw.overflowWheel)); next != nil {
		next.advanceClock(now)
	}
}


// AfterFunc registers f to be executed once d has elapsed and returns the
// Timer that represents the pending call. If the deadline cannot be bucketed
// any more, the task is handed to the work pool straight away.
func (tw *TimeWheel) AfterFunc(d time.Duration, f func()) *Timer {
	deadline := time.Now().UTC().Add(d)

	timer := new(Timer)
	timer.expiration = deadline.UnixMilli()
	timer.task = f

	tw.addOrRun(timer)
	return timer
}

// Start launches the worker pool and a goroutine that advances the wheel on
// every tick. The goroutine runs until Stop is called, at which point it also
// stops the ticker so that its resources are released.
func (tw *TimeWheel) Start() {
	// Keep a local handle so the goroutine always drains and stops the ticker
	// it was started with, even if tw.ticker is reassigned later.
	ticker := time.NewTicker(time.Duration(tw.tickMs) * time.Millisecond)
	tw.ticker = ticker
	tw.WorkPool.StartWorkerPool()

	go func() {
		// Without Stop the ticker keeps firing after the loop has exited.
		defer ticker.Stop()
		for {
			select {
			case t := <-ticker.C:
				tw.advanceClock(t.UnixMilli())
			case <-tw.exitC:
				return
			}
		}
	}()
}

// Stop signals the goroutine launched by Start to exit. The exit channel is
// unbuffered, so the call blocks until that goroutine receives the signal.
func (tw *TimeWheel) Stop() {
	var signal struct{}
	tw.exitC <- signal
}