Added get by interface method
This commit is contained in:
parent
d3ad517530
commit
a84b0dc4a3
3 changed files with 80 additions and 16 deletions
16
README.md
16
README.md
|
@ -17,12 +17,17 @@ di.Register("service id", func () (*Service, error) { /* construct service */ })
|
|||
Get dependencies by type:
|
||||
|
||||
```go
|
||||
services, err := di.Get[Service]()
|
||||
services, err := di.GetByType[Service]()
|
||||
```
|
||||
|
||||
Get dependencies by type and id:
|
||||
```go
|
||||
service, err := di.GetById[Service]("service id")
|
||||
service, err := di.Get[Service]("service id")
|
||||
```
|
||||
|
||||
Get dependencies by interface:
|
||||
```go
|
||||
services, err := di.GetByInterface[Worker]() // Worker is interface for many workers
|
||||
```
|
||||
|
||||
### Go doc
|
||||
|
@ -30,8 +35,9 @@ service, err := di.GetById[Service]("service id")
|
|||
```go
|
||||
package di // import "go.neonxp.dev/di"
|
||||
|
||||
func Get[T any]() ([]*T, error)
|
||||
func GetById[T any](id string) (*T, error)
|
||||
func Get[T any](id string) (*T, error)
|
||||
func GetByInterface[Interface any]() ([]Interface, error)
|
||||
func GetByType[T any]() ([]*T, error)
|
||||
func Register[T any](id string, constructor func() (*T, error))
|
||||
```
|
||||
|
||||
|
@ -53,7 +59,7 @@ di.Register("serviceB", func() (*ServiceB, error) { // <- Register service B, th
|
|||
})
|
||||
|
||||
// Do work ...
|
||||
service, err := di.GetById[ServiceB]("serviceB") // <- Get instantinated service B
|
||||
service, err := di.Get[ServiceB]("serviceB") // <- Get instantinated service B
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
|
44
di.go
44
di.go
|
@ -15,8 +15,13 @@ func init() {
|
|||
cache = sync.Map{}
|
||||
}
|
||||
|
||||
// Register service in di
|
||||
func Register[T any](id string, constructor func() (*T, error)) {
|
||||
services.Store(id, constructor)
|
||||
}
|
||||
|
||||
// Get services by type
|
||||
func Get[T any]() ([]*T, error) {
|
||||
func GetByType[T any]() ([]*T, error) {
|
||||
var err error
|
||||
result := []*T{}
|
||||
services.Range(func(id, constructor any) bool {
|
||||
|
@ -40,13 +45,38 @@ func Get[T any]() ([]*T, error) {
|
|||
return result, err
|
||||
}
|
||||
|
||||
// Get service by type and id
|
||||
func GetById[T any](id string) (*T, error) {
|
||||
// Get services by interface
|
||||
func GetByInterface[Interface any]() ([]Interface, error) {
|
||||
var err error
|
||||
result := []Interface{}
|
||||
services.Range(func(id, constructor any) bool {
|
||||
if constructor, ok := constructor.(func() (Interface, error)); ok {
|
||||
if instance, ok := cache.Load(id); ok {
|
||||
if instance, ok := instance.(Interface); ok {
|
||||
result = append(result, instance)
|
||||
}
|
||||
return true
|
||||
}
|
||||
instance, instErr := constructor()
|
||||
if instErr != nil {
|
||||
err = instErr
|
||||
return false
|
||||
}
|
||||
cache.Store(id, instance)
|
||||
result = append(result, instance)
|
||||
}
|
||||
return true
|
||||
})
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Get service by id and type
|
||||
func Get[T any](id string) (*T, error) {
|
||||
if instance, ok := cache.Load(id); ok {
|
||||
if instance, ok := instance.(*T); ok {
|
||||
return instance, nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid type %t for service %s", instance, id)
|
||||
return nil, fmt.Errorf("invalid type for service %s (%t)", id, instance)
|
||||
}
|
||||
if constructor, ok := services.Load(id); ok {
|
||||
if constructor, ok := constructor.(func() (*T, error)); ok {
|
||||
|
@ -57,11 +87,7 @@ func GetById[T any](id string) (*T, error) {
|
|||
cache.Store(id, instance)
|
||||
return instance, nil
|
||||
}
|
||||
return nil, fmt.Errorf("invalid type %t for service %s", constructor, id)
|
||||
return nil, fmt.Errorf("invalid constructor")
|
||||
}
|
||||
return nil, fmt.Errorf("unknown service %s", id)
|
||||
}
|
||||
|
||||
func Register[T any](id string, constructor func() (*T, error)) {
|
||||
services.Store(id, constructor)
|
||||
}
|
||||
|
|
36
di_test.go
36
di_test.go
|
@ -11,7 +11,7 @@ func ExampleGet() {
|
|||
return &ServiceA{}, nil
|
||||
})
|
||||
di.Register("serviceB", func() (*ServiceB, error) { // <- Register service B, that depends from service A
|
||||
serviceA, err := di.Get[ServiceA]() // <- Get dependency from container by type
|
||||
serviceA, err := di.GetByType[ServiceA]() // <- Get dependency from container by type
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -22,13 +22,29 @@ func ExampleGet() {
|
|||
})
|
||||
|
||||
// Do work...
|
||||
service, err := di.GetById[ServiceB]("serviceB") // <- Get instantinated service B
|
||||
service, err := di.Get[ServiceB]("serviceB") // <- Get instantinated service B
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
service.DoStuff() // Output: Hello, world!
|
||||
}
|
||||
|
||||
func ExampleGet_interface() {
|
||||
di.Register("worker1", func() (*Worker1, error) {
|
||||
return &Worker1{}, nil
|
||||
})
|
||||
di.Register("worker2", func() (*Worker2, error) {
|
||||
return &Worker2{}, nil
|
||||
})
|
||||
workers, err := di.GetByInterface[Worker]()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for _, w := range workers {
|
||||
w.Do()
|
||||
}
|
||||
}
|
||||
|
||||
type ServiceA struct{}
|
||||
|
||||
func (d *ServiceA) DoStuff() {
|
||||
|
@ -42,3 +58,19 @@ type ServiceB struct {
|
|||
func (d *ServiceB) DoStuff() {
|
||||
d.ServiceA.DoStuff()
|
||||
}
|
||||
|
||||
type Worker interface {
|
||||
Do()
|
||||
}
|
||||
|
||||
type Worker1 struct{}
|
||||
|
||||
func (w *Worker1) Do() {
|
||||
fmt.Println("Worker 1 says hello")
|
||||
}
|
||||
|
||||
type Worker2 struct{}
|
||||
|
||||
func (w *Worker2) Do() {
|
||||
fmt.Println("Worker 2 says hello")
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue