diff --git a/example/main.go b/example/main.go index 11e1fb0..37b9037 100644 --- a/example/main.go +++ b/example/main.go @@ -7,7 +7,6 @@ import ( "os" "os/signal" - "github.com/qri-io/jsonschema" "go.neonxp.dev/jsonrpc2/rpc" "go.neonxp.dev/jsonrpc2/rpc/middleware" "go.neonxp.dev/jsonrpc2/transport" @@ -18,10 +17,12 @@ func main() { rpc.WithLogger(rpc.StdLogger), rpc.WithTransport(&transport.HTTP{Bind: ":8000", CORSOrigin: "*"}), ) + // Set options after constructor - validation, err := middleware.Validation(map[string]middleware.MethodSchema{ + serviceSchema := ` + { "divide": { - Request: *jsonschema.Must(`{ + "request": { "type": "object", "properties": { "a": { @@ -33,8 +34,8 @@ func main() { } }, "required": ["a", "b"] - }`), - Response: *jsonschema.Must(`{ + }, + "response": { "type": "object", "properties": { "quo": { @@ -45,9 +46,28 @@ func main() { } }, "required": ["quo", "rem"] - }`), + } }, - }) + "multiply": { + "request": { + "type": "object", + "properties": { + "a": { + "type": "integer" + }, + "b": { + "type": "integer" + } + }, + "required": ["a", "b"] + }, + "response": { + "type": "integer" + } + } + }` + + validation, err := middleware.Validation(middleware.MustSchema(serviceSchema)) if err != nil { log.Fatal(err) } diff --git a/rpc/middleware/validation.go b/rpc/middleware/validation.go index e994383..80b4cfd 100644 --- a/rpc/middleware/validation.go +++ b/rpc/middleware/validation.go @@ -30,25 +30,37 @@ import ( "go.neonxp.dev/jsonrpc2/rpc" ) -type MethodSchema struct { - Request jsonschema.Schema - Response jsonschema.Schema +type ServiceSchema map[string]MethodSchema + +func MustSchema(schema string) ServiceSchema { + ss := new(ServiceSchema) + if err := json.Unmarshal([]byte(schema), ss); err != nil { + panic(err) + } + return *ss } -func Validation(serviceSchema map[string]MethodSchema) (rpc.Middleware, error) { +type MethodSchema struct { + Request *jsonschema.Schema `json:"request"` + Response *jsonschema.Schema `json:"response"` +} + +func Validation(serviceSchema ServiceSchema) (rpc.Middleware, error) { return func(handler rpc.RpcHandler) rpc.RpcHandler { return func(ctx context.Context, req *rpc.RpcRequest) *rpc.RpcResponse { - if rs, ok := serviceSchema[strings.ToLower(req.Method)]; ok { - if errResp := formatError(ctx, req.Id, rs.Request, req.Params); errResp != nil { + rs, hasSchema := serviceSchema[strings.ToLower(req.Method)] + if hasSchema && rs.Request != nil { + if errResp := formatError(ctx, req.Id, *rs.Request, req.Params); errResp != nil { return errResp } - resp := handler(ctx, req) - if errResp := formatError(ctx, req.Id, rs.Response, resp.Result); errResp != nil { - return errResp - } - return resp } - return handler(ctx, req) + resp := handler(ctx, req) + if hasSchema && rs.Response != nil { + if errResp := formatError(ctx, req.Id, *rs.Response, resp.Result); errResp != nil { + return errResp + } + } + return resp } }, nil }