一个golang并行库源码解析

    xiaoxiao2026-01-26  8

    场景

    有这样一种场景:四个任务A、B、C, D,其中任务B和C需要并发执行,得到结果1, 任务A执行得到结果2, 结果1和2作为任务D的参数传入,然后执行任务D得到最终结果。我们可以将任务执行顺序用如下图标识:

    jobA jobB jobC \ \ / \ \ / \ middle \ / \ / jobD

    这是一个典型的多任务并发场景,实际上随着任务数量的增多,任务逻辑会更加复杂,如何编写可维护健壮的逻辑代码变得十分重要,虽然golang提供了同步机制,但是需要写很多重复无用的Add/Wait/Done代码,而且代码可读性也很差,这是不能容忍的。

    本文介绍一个开源的golang并行库,源码地址https://github.com/buptmiao/parallel

    数据结构

    1. parallel结构体

    type Parallel struct { wg *sync.WaitGroup pipes []*Pipeline wgChild *sync.WaitGroup children []*Parallel exception *Handler }

    parallel定义了一个多任务并发实例,主要包括:并发任务管道(pipes)、子任务并发实例(children)、子任务实例等待锁(wgChild)、当前并发任务实例等待锁(wg)

    2. pipeline结构体

    type Pipeline struct { handlers []*Handler } type Handler struct { f interface{} args []interface{} receivers []interface{} }

    这里pipeline实际上是一系列并发任务实例handler,每一个handler包括任务函数f, 传入参数args以及返回结果receivers

    parallel相关代码

    新建parallel实例

    func NewParallel() *Parallel { res := new(Parallel) res.wg = new(sync.WaitGroup) res.wgChild = new(sync.WaitGroup) res.pipes = make([]*Pipeline, 0, 10) return res }

    注册handler

    func (p *Parallel) Register(f interface{}, args ...interface{}) *Handler { return p.NewPipeline().Register(f, args...) } func (p *Parallel) NewPipeline() *Pipeline { pipe := NewPipeline() p.Add(pipe) return pipe } func (p *Parallel) Add(pipes ...*Pipeline) *Parallel { p.wg.Add(len(pipes)) p.pipes = append(p.pipes, pipes...) return p }

    新建子parallel实例

    func (p *Parallel) NewChild() *Parallel { child := NewParallel() child.exception = p.exception p.AddChildren(child) return child } func (p *Parallel) AddChildren(children ...*Parallel) *Parallel { p.wgChild.Add(len(children)) p.children = append(p.children, children...) return p }

    任务运行

    func (p *Parallel) Run() { for _, child := range p.children { // this func will never panic go func(ch *Parallel) { ch.Run() p.wgChild.Done() }(child) } p.wgChild.Wait() //wait children instance done p.do() //run p.wg.Wait() //wait all job done } func (p *Parallel) do() { for _, pipe := range p.pipes { go p.Do() } }

    pipeline相关代码

    新建pipeline实例

    func NewPipeline() *Pipeline { res := new(Pipeline) return res }

    注册handler

    func (p *Pipeline) Register(f interface{}, args ...interface{}) *Handler { h := NewHandler(f, args...) p.Add(h) return h }

    添加handler

    func (p *Pipeline) Add(hs ...*Handler) *Pipeline { p.handlers = append(p.handlers, hs...) return p }

    任务运行

    func (p *Pipeline) Do() { for _, h := range p.handlers { h.Do() } }

    handler相关代码

    新建handler实例

    func NewHandler(f interface{}, args ...interface{}) *Handler { res := new(Handler) res.f = f res.args = args return res }

    运行任务

    func (h *Handler) Do() { f := reflect.ValueOf(h.f) typ := f.Type() //check if f is a function if typ.Kind() != reflect.Func { panic(ErrArgNotFunction) } //check input length, only check '>' is to allow varargs. if typ.NumIn() > len(h.args) { panic(ErrInArgLenNotMatch) } //check output length if typ.NumOut() != len(h.receivers) { panic(ErrOutArgLenNotMatch) } //check if output args is ptr for _, v := range h.receivers { t := reflect.ValueOf(v) if t.Type().Kind() != reflect.Ptr { panic(ErrRecvArgTypeNotPtr) } if t.IsNil() { panic(ErrRecvArgNil) } } inputs := make([]reflect.Value, len(h.args)) for i := 0; i < len(h.args); i++ { if h.args[i] == nil { inputs[i] = reflect.Zero(f.Type().In(i)) } else { inputs[i] = reflect.ValueOf(h.args[i]) } } out := f.Call(inputs) for i := 0; i < len(h.receivers); i++ { v := reflect.ValueOf(h.receivers[i]) v.Elem().Set(out[i]) } }

    demo

    package main import "github.com/buptmiao/parallel" func testJobA(x, y int) int { return x - y } func testJobB(x, y int) int { return x + y } func testJobC(x, y *int, z int) float64 { return float64((*x)*(*y)) / float64(z) } func main() { var x, y int var z float64 p := parallel.NewParallel() ch1 := p.NewChild() ch1.Register(testJobA, 1, 2).SetReceivers(&x) ch2 := p.NewChild() ch2.Register(testJobB, 1, 2).SetReceivers(&y) p.Register(testJobC, &x, &y, 2).SetReceivers(&z) p.Run() if x != -1 || y != 3 || z != -1.5 { panic("unexpected result") } } 相关资源:panicparse:使您的应用程序崩溃(Golang)-源码
    最新回复(0)