package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// call tracks a single in-flight invocation of the function passed to
// Group.Do. Callers that arrive while it is running wait for done to be
// closed and then read val and err, which are final from that point on.
type call struct {
	done chan struct{}
	val  string
	err  error
}

// Group deduplicates concurrent calls that share a key: while a call for a
// key is in flight, later callers for the same key wait for it and receive
// its result instead of invoking their own function. The zero value is ready
// to use. A Group contains a sync.Mutex and must not be copied after first use.
type Group struct {
	mu sync.Mutex
	// calls holds the in-flight calls, keyed by the key passed to Do.
	// Guarded by mu; lazily allocated by Do.
	calls map[string]*call
}

// Do runs fn for key and returns its result, unless a call for the same key
// is already in flight, in which case it waits for that call and returns its
// result instead.
//
// The boolean result reports whether the value was taken from another
// caller's in-flight call (true) or produced by this caller's own fn (false).
//
// If fn panics (or calls runtime.Goexit), the panic propagates in the caller
// that ran fn as usual, while every waiting caller is released with a non-nil
// error instead of blocking forever. Once a call completes its key is
// forgotten, so a later Do for the same key runs fn again.
func (g *Group) Do(key string, fn func() (string, error)) (string, bool, error) {
	g.mu.Lock()
	if g.calls == nil {
		g.calls = make(map[string]*call)
	}
	if c, ok := g.calls[key]; ok {
		g.mu.Unlock()
		<-c.done
		return c.val, true, c.err
	}
	c := &call{done: make(chan struct{})}
	g.calls[key] = c
	g.mu.Unlock()

	g.doCall(c, key, fn)
	return c.val, false, c.err
}

// doCall runs fn on behalf of c and publishes the outcome. It always removes
// key from g.calls and closes c.done, even if fn panics or exits the goroutine.
func (g *Group) doCall(c *call, key string, fn func() (string, error)) {
	returned := false
	defer func() {
		if !returned {
			// fn never produced a result; without this, waiters would
			// observe a zero value paired with a nil error.
			c.err = fmt.Errorf("singleflight: fn for key %q panicked or called runtime.Goexit", key)
		}
		g.mu.Lock()
		// Delete and close in one critical section: a caller that finds the
		// entry is guaranteed to be waiting on a call that was still in
		// flight, and any Do that starts after completion runs fn afresh.
		delete(g.calls, key)
		close(c.done)
		g.mu.Unlock()
	}()

	c.val, c.err = fn()
	returned = true
}

// expensiveCalls counts how many times expensiveFunc has actually executed.
var expensiveCalls atomic.Int64

// expensiveFunc simulates slow work: it records the invocation in
// expensiveCalls, waits 300ms, and returns a string naming which invocation
// produced it. It always returns a nil error.
func expensiveFunc() (string, error) {
	const simulatedLatency = 300 * time.Millisecond

	callNumber := expensiveCalls.Add(1)
	time.Sleep(simulatedLatency)

	result := fmt.Sprintf("result from expensive call #%d", callNumber)
	return result, nil
}

// main demonstrates Group.Do in two phases. First, six concurrent callers
// request the same key and can share one execution of expensiveFunc. Then a
// later sequential call for that key runs expensiveFunc again.
func main() {
	var (
		g  Group
		wg sync.WaitGroup
	)

	// runWorker performs one Do call for "same-key" and prints its outcome.
	// It reports false when Do returned an error.
	runWorker := func(id int) bool {
		val, shared, err := g.Do("same-key", expensiveFunc)
		if err != nil {
			fmt.Printf("worker=%d error=%v\n", id, err)
			return false
		}
		fmt.Printf("worker=%d shared=%v value=%q\n", id, shared, val)
		return true
	}

	// Phase 1: launch every worker at once so they all ask for the key while
	// the first call is still in flight.
	const concurrentWorkers = 6
	wg.Add(concurrentWorkers)
	for id := 0; id < concurrentWorkers; id++ {
		go func(workerID int) {
			defer wg.Done()
			runWorker(workerID)
		}(id)
	}
	wg.Wait()
	fmt.Printf("expensive calls: %d\n", expensiveCalls.Load())

	// Phase 2: the earlier call has finished, so this one is not shared and
	// expensiveFunc runs again.
	if !runWorker(0) {
		return
	}
	fmt.Printf("expensive calls: %d\n", expensiveCalls.Load())
}
