diff --git a/README.md b/README.md index c20db6e..70745a9 100644 --- a/README.md +++ b/README.md @@ -17,12 +17,17 @@ di.Register("service id", func () (*Service, error) { /* construct service */ }) Get dependencies by type: ```go -services, err := di.Get[Service]() +services, err := di.GetByType[Service]() ``` Get dependencies by type and id: ```go -service, err := di.GetById[Service]("service id") +service, err := di.Get[Service]("service id") +``` + +Get dependencies by interface: +```go +services, err := di.GetByInterface[Worker]() // Worker is interface for many workers ``` ### Go doc @@ -30,8 +35,9 @@ service, err := di.GetById[Service]("service id") ```go package di // import "go.neonxp.dev/di" -func Get[T any]() ([]*T, error) -func GetById[T any](id string) (*T, error) +func Get[T any](id string) (*T, error) +func GetByInterface[Interface any]() ([]Interface, error) +func GetByType[T any]() ([]*T, error) func Register[T any](id string, constructor func() (*T, error)) ``` @@ -53,7 +59,7 @@ di.Register("serviceB", func() (*ServiceB, error) { // <- Register service B, th }) // Do work ... -service, err := di.GetById[ServiceB]("serviceB") // <- Get instantinated service B +service, err := di.Get[ServiceB]("serviceB") // <- Get instantinated service B if err != nil { panic(err) } diff --git a/di.go b/di.go index b015577..8e41ee9 100644 --- a/di.go +++ b/di.go @@ -15,8 +15,13 @@ func init() { cache = sync.Map{} } +// Register service in di +func Register[T any](id string, constructor func() (*T, error)) { + services.Store(id, constructor) +} + // Get services by type -func Get[T any]() ([]*T, error) { +func GetByType[T any]() ([]*T, error) { var err error result := []*T{} services.Range(func(id, constructor any) bool { @@ -40,13 +45,38 @@ func Get[T any]() ([]*T, error) { return result, err } -// Get service by type and id -func GetById[T any](id string) (*T, error) { +// Get services by interface +func GetByInterface[Interface any]() ([]Interface, error) { + var err error + result := []Interface{} + services.Range(func(id, constructor any) bool { + if constructor, ok := constructor.(func() (Interface, error)); ok { + if instance, ok := cache.Load(id); ok { + if instance, ok := instance.(Interface); ok { + result = append(result, instance) + } + return true + } + instance, instErr := constructor() + if instErr != nil { + err = instErr + return false + } + cache.Store(id, instance) + result = append(result, instance) + } + return true + }) + return result, err +} + +// Get service by id and type +func Get[T any](id string) (*T, error) { if instance, ok := cache.Load(id); ok { if instance, ok := instance.(*T); ok { return instance, nil } - return nil, fmt.Errorf("invalid type %t for service %s", instance, id) + return nil, fmt.Errorf("invalid type for service %s (%t)", id, instance) } if constructor, ok := services.Load(id); ok { if constructor, ok := constructor.(func() (*T, error)); ok { @@ -57,11 +87,7 @@ func GetById[T any](id string) (*T, error) { cache.Store(id, instance) return instance, nil } - return nil, fmt.Errorf("invalid type %t for service %s", constructor, id) + return nil, fmt.Errorf("invalid constructor") } return nil, fmt.Errorf("unknown service %s", id) } - -func Register[T any](id string, constructor func() (*T, error)) { - services.Store(id, constructor) -} diff --git a/di_test.go b/di_test.go index cd0eb23..a67a82e 100644 --- a/di_test.go +++ b/di_test.go @@ -11,7 +11,7 @@ func ExampleGet() { return &ServiceA{}, nil }) di.Register("serviceB", func() (*ServiceB, error) { // <- Register service B, that depends from service A - serviceA, err := di.Get[ServiceA]() // <- Get dependency from container by type + serviceA, err := di.GetByType[ServiceA]() // <- Get dependency from container by type if err != nil { return nil, err } @@ -22,13 +22,29 @@ func ExampleGet() { }) // Do work... - service, err := di.GetById[ServiceB]("serviceB") // <- Get instantinated service B + service, err := di.Get[ServiceB]("serviceB") // <- Get instantinated service B if err != nil { panic(err) } service.DoStuff() // Output: Hello, world! } +func ExampleGet_interface() { + di.Register("worker1", func() (*Worker1, error) { + return &Worker1{}, nil + }) + di.Register("worker2", func() (*Worker2, error) { + return &Worker2{}, nil + }) + workers, err := di.GetByInterface[Worker]() + if err != nil { + panic(err) + } + for _, w := range workers { + w.Do() + } +} + type ServiceA struct{} func (d *ServiceA) DoStuff() { @@ -42,3 +58,19 @@ type ServiceB struct { func (d *ServiceB) DoStuff() { d.ServiceA.DoStuff() } + +type Worker interface { + Do() +} + +type Worker1 struct{} + +func (w *Worker1) Do() { + fmt.Println("Worker 1 says hello") +} + +type Worker2 struct{} + +func (w *Worker2) Do() { + fmt.Println("Worker 2 says hello") +}