diff --git a/go.mod b/go.mod index c3249a8..571b21e 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,3 @@ require ( golang.org/x/tools v0.1.8 google.golang.org/protobuf v1.27.1 ) - -//replace go.unistack.org/micro/v3 => ../micro -//replace go.unistack.org/micro-proto => ../micro-proto diff --git a/openapiv3.go b/openapiv3.go index 8deb3e0..4c1f255 100644 --- a/openapiv3.go +++ b/openapiv3.go @@ -12,17 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. // + package main import ( "fmt" "log" - "net/http" + "net/url" "regexp" "sort" "strings" - annotations "go.unistack.org/micro-proto/v3/api" + "go.unistack.org/micro-proto/v3/api" v3 "go.unistack.org/micro-proto/v3/openapiv3" "google.golang.org/protobuf/compiler/protogen" jsonpb "google.golang.org/protobuf/encoding/protojson" @@ -30,22 +31,30 @@ import ( "google.golang.org/protobuf/reflect/protoreflect" ) +// openapiv3Generator holds internal state needed to generate an OpenAPIv3 document for a transcoded Protocol Buffer service. type openapiv3Generator struct { - plugin *protogen.Plugin + circularDepth int + naming string + plugin *protogen.Plugin requiredSchemas []string // Names of schemas that need to be generated. generatedSchemas []string // Names of schemas that have already been generated. linterRulePattern *regexp.Regexp - namePattern *regexp.Regexp + pathPattern *regexp.Regexp + namedPathPattern *regexp.Regexp } +// openapiv3Generate creates a new generator for a protoc plugin invocation. func (g *Generator) openapiv3Generate(component string, plugin *protogen.Plugin) error { og := &openapiv3Generator{ + circularDepth: 2, plugin: plugin, + naming: "proto", requiredSchemas: make([]string, 0), generatedSchemas: make([]string, 0), linterRulePattern: regexp.MustCompile(`\(-- .* --\)`), - namePattern: regexp.MustCompile("{(.*)=(.*)}"), + pathPattern: regexp.MustCompile("{([^=}]+)}"), + namedPathPattern: regexp.MustCompile("{(.+)=(.+)}"), } d := og.buildDocumentV3(plugin) @@ -65,13 +74,7 @@ func (g *Generator) openapiv3Generate(component string, plugin *protogen.Plugin) // buildDocumentV3 builds an OpenAPIv3 document for a plugin request. func (g *openapiv3Generator) buildDocumentV3(plugin *protogen.Plugin) *v3.Document { - d := &v3.Document{} - d.Openapi = "3.0.3" - d.Info = &v3.Info{ - Title: "", - Version: "0.0.1", - Description: "", - } + d := &v3.Document{Openapi: "3.0.3", Info: &v3.Info{Version: "0.0.1"}} for _, file := range plugin.Files { if !proto.HasExtension(file.Desc.Options(), v3.E_Openapiv3Swagger) { @@ -101,16 +104,105 @@ func (g *openapiv3Generator) buildDocumentV3(plugin *protogen.Plugin) *v3.Docume AdditionalProperties: []*v3.NamedSchemaOrReference{}, }, } + for _, file := range g.plugin.Files { - g.addPathsToDocumentV3(d, file) + if file.Generate { + g.addPathsToDocumentV3(d, file) + } } + + // If there is only 1 service, then use it's title for the document, + // if the document is missing it. + if len(d.Tags) == 1 { + if d.Info.Title == "" && d.Tags[0].Name != "" { + d.Info.Title = d.Tags[0].Name + " API" + } + if d.Info.Description == "" { + d.Info.Description = d.Tags[0].Description + } + d.Tags[0].Description = "" + } + for len(g.requiredSchemas) > 0 { count := len(g.requiredSchemas) for _, file := range g.plugin.Files { - g.addSchemasToDocumentV3(d, file) + g.addSchemasToDocumentV3(d, file.Messages) } g.requiredSchemas = g.requiredSchemas[count:len(g.requiredSchemas)] } + + allServers := []string{} + + // If paths methods has servers, but they're all the same, then move servers to path level + for _, path := range d.Paths.Path { + servers := []string{} + // Only 1 server will ever be set, per method, by the generator + + if path.Value.Get != nil && len(path.Value.Get.Servers) == 1 { + servers = appendUniuqe(servers, path.Value.Get.Servers[0].Url) + allServers = appendUniuqe(servers, path.Value.Get.Servers[0].Url) + } + if path.Value.Post != nil && len(path.Value.Post.Servers) == 1 { + servers = appendUniuqe(servers, path.Value.Post.Servers[0].Url) + allServers = appendUniuqe(servers, path.Value.Post.Servers[0].Url) + } + if path.Value.Put != nil && len(path.Value.Put.Servers) == 1 { + servers = appendUniuqe(servers, path.Value.Put.Servers[0].Url) + allServers = appendUniuqe(servers, path.Value.Put.Servers[0].Url) + } + if path.Value.Delete != nil && len(path.Value.Delete.Servers) == 1 { + servers = appendUniuqe(servers, path.Value.Delete.Servers[0].Url) + allServers = appendUniuqe(servers, path.Value.Delete.Servers[0].Url) + } + if path.Value.Patch != nil && len(path.Value.Patch.Servers) == 1 { + servers = appendUniuqe(servers, path.Value.Patch.Servers[0].Url) + allServers = appendUniuqe(servers, path.Value.Patch.Servers[0].Url) + } + + if len(servers) == 1 { + path.Value.Servers = []*v3.Server{{Url: servers[0]}} + + if path.Value.Get != nil { + path.Value.Get.Servers = nil + } + if path.Value.Post != nil { + path.Value.Post.Servers = nil + } + if path.Value.Put != nil { + path.Value.Put.Servers = nil + } + if path.Value.Delete != nil { + path.Value.Delete.Servers = nil + } + if path.Value.Patch != nil { + path.Value.Patch.Servers = nil + } + } + } + + // Set all servers on API level + if len(allServers) > 0 { + d.Servers = []*v3.Server{} + for _, server := range allServers { + d.Servers = append(d.Servers, &v3.Server{Url: server}) + } + } + + // If there is only 1 server, we can safely remove all path level servers + if len(allServers) == 1 { + for _, path := range d.Paths.Path { + path.Value.Servers = nil + } + } + + // Sort the tags. + { + pairs := d.Tags + sort.Slice(pairs, func(i, j int) bool { + return pairs[i].Name < pairs[j].Name + }) + d.Tags = pairs + } // Sort the paths. { pairs := d.Paths.Path @@ -131,9 +223,11 @@ func (g *openapiv3Generator) buildDocumentV3(plugin *protogen.Plugin) *v3.Docume } // filterCommentString removes line breaks and linter rules from comments. -func (g *openapiv3Generator) filterCommentString(c protogen.Comments) string { +func (g *openapiv3Generator) filterCommentString(c protogen.Comments, removeNewLines bool) string { comment := string(c) - comment = strings.Replace(comment, "\n", "", -1) + if removeNewLines { + comment = strings.Replace(comment, "\n", "", -1) + } comment = g.linterRulePattern.ReplaceAllString(comment, "") return strings.TrimSpace(comment) } @@ -141,70 +235,238 @@ func (g *openapiv3Generator) filterCommentString(c protogen.Comments) string { // addPathsToDocumentV3 adds paths from a specified file descriptor. func (g *openapiv3Generator) addPathsToDocumentV3(d *v3.Document, file *protogen.File) { for _, service := range file.Services { - comment := g.filterCommentString(service.Comments.Leading) - if d.Info.Title == "" { - d.Info.Title = service.GoName - } - if d.Info.Description == "" { - d.Info.Description = comment - } + annotationsCount := 0 + for _, method := range service.Methods { - comment := g.filterCommentString(method.Comments.Leading) + comment := g.filterCommentString(method.Comments.Leading, false) inputMessage := method.Input outputMessage := method.Output operationID := service.GoName + "_" + method.GoName + eopt := proto.GetExtension(method.Desc.Options(), v3.E_Openapiv3Operation) if eopt != nil && eopt != v3.E_Openapiv3Operation.InterfaceOf(v3.E_Openapiv3Operation.Zero()) { - opt := eopt.(*v3.Operation) - if opt.OperationId != "" { + if opt, ok := eopt.(*v3.Operation); ok && opt.OperationId != "" { operationID = opt.OperationId } } - xt := annotations.E_Http - extension := proto.GetExtension(method.Desc.Options(), xt) + var path string var methodName string var body string - if extension != nil && extension != xt.InterfaceOf(xt.Zero()) { - rule := extension.(*annotations.HttpRule) + + extHTTP := proto.GetExtension(method.Desc.Options(), api.E_Http) + if extHTTP != nil && extHTTP != api.E_Http.InterfaceOf(api.E_Http.Zero()) { + annotationsCount++ + + rule := extHTTP.(*api.HttpRule) body = rule.Body switch pattern := rule.Pattern.(type) { - case *annotations.HttpRule_Get: + case *api.HttpRule_Get: path = pattern.Get methodName = "GET" - case *annotations.HttpRule_Post: + case *api.HttpRule_Post: path = pattern.Post methodName = "POST" - case *annotations.HttpRule_Put: + case *api.HttpRule_Put: path = pattern.Put methodName = "PUT" - case *annotations.HttpRule_Delete: + case *api.HttpRule_Delete: path = pattern.Delete methodName = "DELETE" - case *annotations.HttpRule_Patch: + case *api.HttpRule_Patch: path = pattern.Patch methodName = "PATCH" - case *annotations.HttpRule_Custom: + case *api.HttpRule_Custom: path = "custom-unsupported" default: path = "unknown-unsupported" } } + if methodName != "" { + defaultHost := proto.GetExtension(service.Desc.Options(), api.E_DefaultHost).(string) + op, path2 := g.buildOperationV3( - file, method, operationID, comment, path, body, inputMessage, outputMessage) + file, method, operationID, service.GoName, comment, defaultHost, path, body, inputMessage, outputMessage) g.addOperationV3(d, op, path2, methodName) } } + + if annotationsCount > 0 { + comment := g.filterCommentString(service.Comments.Leading, false) + d.Tags = append(d.Tags, &v3.Tag{Name: service.GoName, Description: comment}) + } } } +func (g *openapiv3Generator) formatMessageRef(name string) string { + if g.naming == "proto" { + return name + } + + if len(name) > 1 { + return strings.ToUpper(name[0:1]) + name[1:] + } + + if len(name) == 1 { + return strings.ToLower(name) + } + + return name +} + +func (g *openapiv3Generator) getMessageName(message protoreflect.MessageDescriptor) string { + prefix := "" + parent := message.Parent() + if message != nil { + if _, ok := parent.(protoreflect.MessageDescriptor); ok { + prefix = string(parent.Name()) + "_" + prefix + } + } + + return prefix + string(message.Name()) +} + +func (g *openapiv3Generator) formatMessageName(message *protogen.Message) string { + name := g.getMessageName(message.Desc) + + if g.naming == "proto" { + return name + } + + if len(name) > 0 { + return strings.ToUpper(name[0:1]) + name[1:] + } + + return name +} + +func (g *openapiv3Generator) formatFieldName(field *protogen.Field) string { + log.Printf("proto %s json %s", string(field.Desc.Name()), field.Desc.JSONName()) + if g.naming == "proto" { + return string(field.Desc.Name()) + } + + return field.Desc.JSONName() +} + +func (g *openapiv3Generator) findField(name string, inMessage *protogen.Message) *protogen.Field { + for _, field := range inMessage.Fields { + if string(field.Desc.Name()) == name || string(field.Desc.JSONName()) == name { + return field + } + } + + return nil +} + +func (g *openapiv3Generator) findAndFormatFieldName(name string, inMessage *protogen.Message) string { + field := g.findField(name, inMessage) + if field != nil { + return g.formatFieldName(field) + } + + return name +} + +// Note that fields which are mapped to URL query parameters must have a primitive type +// or a repeated primitive type or a non-repeated message type. +// In the case of a repeated type, the parameter can be repeated in the URL as ...?param=A¶m=B. +// In the case of a message type, each field of the message is mapped to a separate parameter, +// such as ...?foo.a=A&foo.b=B&foo.c=C. +// +// maps, Struct and Empty can NOT be used +// messages can have any number of sub messages - including circular (e.g. sub.subsub.sub.subsub.id) + +// buildQueryParamsV3 extracts any valid query params, including sub and recursive messages +func (g *openapiv3Generator) buildQueryParamsV3(field *protogen.Field) []*v3.ParameterOrReference { + depths := map[string]int{} + return g._buildQueryParamsV3(field, depths) +} + +// depths are used to keep track of how many times a message's fields has been seen +func (g *openapiv3Generator) _buildQueryParamsV3(field *protogen.Field, depths map[string]int) []*v3.ParameterOrReference { + parameters := []*v3.ParameterOrReference{} + + queryFieldName := g.formatFieldName(field) + fieldDescription := g.filterCommentString(field.Comments.Leading, true) + + if field.Desc.IsMap() { + // Map types are not allowed in query parameteres + return parameters + } else if field.Desc.Kind() == protoreflect.MessageKind { + if field.Desc.IsList() { + // Only non-repeated message types are valid + return parameters + } + + // Represent field masks directly as strings (don't expand them). + if g.fullMessageTypeName(field.Desc.Message()) == ".google.protobuf.FieldMask" { + fieldSchema := g.schemaOrReferenceForField(field.Desc) + parameters = append(parameters, + &v3.ParameterOrReference{ + Oneof: &v3.ParameterOrReference_Parameter{ + Parameter: &v3.Parameter{ + Name: queryFieldName, + In: "query", + Description: fieldDescription, + Required: false, + Schema: fieldSchema, + }, + }, + }) + return parameters + } + log.Printf("DDDD %#+v", field.Message) + // Sub messages are allowed, even circular, as long as the final type is a primitive. + // Go through each of the sub message fields + for _, subField := range field.Message.Fields { + subFieldFullName := string(subField.Desc.FullName()) + seen, ok := depths[subFieldFullName] + if !ok { + depths[subFieldFullName] = 0 + } + if seen < g.circularDepth { + depths[subFieldFullName]++ + subParams := g._buildQueryParamsV3(subField, depths) + for _, subParam := range subParams { + if param, ok := subParam.Oneof.(*v3.ParameterOrReference_Parameter); ok { + param.Parameter.Name = queryFieldName + "." + param.Parameter.Name + parameters = append(parameters, subParam) + } + } + } + } + + } else if field.Desc.Kind() != protoreflect.GroupKind { + // schemaOrReferenceForField also handles array types + fieldSchema := g.schemaOrReferenceForField(field.Desc) + + parameters = append(parameters, + &v3.ParameterOrReference{ + Oneof: &v3.ParameterOrReference_Parameter{ + Parameter: &v3.Parameter{ + Name: queryFieldName, + In: "query", + Description: fieldDescription, + Required: false, + Schema: fieldSchema, + }, + }, + }) + } + + return parameters +} + // buildOperationV3 constructs an operation for a set of values. func (g *openapiv3Generator) buildOperationV3( file *protogen.File, method *protogen.Method, operationID string, + tagName string, description string, + defaultHost string, path string, bodyField string, inputMessage *protogen.Message, @@ -217,6 +479,7 @@ func (g *openapiv3Generator) buildOperationV3( } // Initialize the list of operation parameters. parameters := []*v3.ParameterOrReference{} + // Build a list of header parameters. eopt := proto.GetExtension(method.Desc.Options(), v3.E_Openapiv3Operation) if eopt != nil && eopt != v3.E_Openapiv3Operation.InterfaceOf(v3.E_Openapiv3Operation.Zero()) { @@ -233,86 +496,117 @@ func (g *openapiv3Generator) buildOperationV3( sparameters[parameter.Name] = struct{}{} } - // Build a list of path parameters. - pathParameters := make([]string, 0) - if matches := g.namePattern.FindStringSubmatch(path); matches != nil { + // Find simple path parameters like {id} + if allMatches := g.pathPattern.FindAllStringSubmatch(path, -1); allMatches != nil { + for _, matches := range allMatches { + // Add the value to the list of covered parameters. + coveredParameters = append(coveredParameters, matches[1]) + pathParameter := g.findAndFormatFieldName(matches[1], inputMessage) + path = strings.Replace(path, matches[1], pathParameter, 1) + + // Add the path parameters to the operation parameters. + var fieldSchema *v3.SchemaOrReference + + var fieldDescription string + field := g.findField(pathParameter, inputMessage) + if field != nil { + fieldSchema = g.schemaOrReferenceForField(field.Desc) + fieldDescription = g.filterCommentString(field.Comments.Leading, true) + } else { + // If field dooes not exist, it is safe to set it to string, as it is ignored downstream + fieldSchema = &v3.SchemaOrReference{ + Oneof: &v3.SchemaOrReference_Schema{ + Schema: &v3.Schema{ + Type: "string", + }, + }, + } + } + + parameters = append(parameters, + &v3.ParameterOrReference{ + Oneof: &v3.ParameterOrReference_Parameter{ + Parameter: &v3.Parameter{ + Name: pathParameter, + In: "path", + Description: fieldDescription, + Required: true, + Schema: fieldSchema, + }, + }, + }) + } + } + + // Find named path parameters like {name=shelves/*} + if matches := g.namedPathPattern.FindStringSubmatch(path); matches != nil { + // Build a list of named path parameters. + namedPathParameters := make([]string, 0) + // Add the "name=" "name" value to the list of covered parameters. coveredParameters = append(coveredParameters, matches[1]) // Convert the path from the starred form to use named path parameters. starredPath := matches[2] parts := strings.Split(starredPath, "/") // The starred path is assumed to be in the form "things/*/otherthings/*". - // We want to convert it to "things/{thing}/otherthings/{otherthing}". + // We want to convert it to "things/{thingsId}/otherthings/{otherthingsId}". for i := 0; i < len(parts)-1; i += 2 { section := parts[i] - parameter := singular(section) - parts[i+1] = "{" + parameter + "}" - pathParameters = append(pathParameters, parameter) + namedPathParameter := g.findAndFormatFieldName(section, inputMessage) + namedPathParameter = singular(namedPathParameter) + parts[i+1] = "{" + namedPathParameter + "}" + namedPathParameters = append(namedPathParameters, namedPathParameter) } // Rewrite the path to use the path parameters. newPath := strings.Join(parts, "/") path = strings.Replace(path, matches[0], newPath, 1) - } - // Add the path parameters to the operation parameters. - for _, pathParameter := range pathParameters { - if _, ok := sparameters[pathParameter]; ok { - continue - } - parameters = append(parameters, - &v3.ParameterOrReference{ - Oneof: &v3.ParameterOrReference_Parameter{ - Parameter: &v3.Parameter{ - Name: pathParameter, - In: "path", - Required: true, - Description: "The " + pathParameter + " id.", - Schema: &v3.SchemaOrReference{ - Oneof: &v3.SchemaOrReference_Schema{ - Schema: &v3.Schema{ - Type: "string", - }, - }, - }, - }, - }, - }) - } - // Add any unhandled fields in the request message as query parameters. - if bodyField != "*" { - for _, field := range inputMessage.Fields { - fieldName := string(field.Desc.Name()) - if !contains(coveredParameters, fieldName) { - if _, ok := sparameters[fieldName]; ok { - continue - } - // Get the field description from the comments. - fieldDescription := g.filterCommentString(field.Comments.Leading) - parameters = append(parameters, - &v3.ParameterOrReference{ - Oneof: &v3.ParameterOrReference_Parameter{ - Parameter: &v3.Parameter{ - Name: fieldName, - In: "query", - Description: fieldDescription, - Required: false, - Schema: &v3.SchemaOrReference{ - Oneof: &v3.SchemaOrReference_Schema{ - Schema: &v3.Schema{ - Type: "string", - }, + // Add the named path parameters to the operation parameters. + for _, namedPathParameter := range namedPathParameters { + if _, ok := sparameters[namedPathParameter]; ok { + continue + } + parameters = append(parameters, + &v3.ParameterOrReference{ + Oneof: &v3.ParameterOrReference_Parameter{ + Parameter: &v3.Parameter{ + Name: namedPathParameter, + In: "path", + Required: true, + Description: "The " + namedPathParameter + " id.", + Schema: &v3.SchemaOrReference{ + Oneof: &v3.SchemaOrReference_Schema{ + Schema: &v3.Schema{ + Type: "string", }, }, }, }, - }) + }, + }) + } + } + + // Add any unhandled fields in the request message as query parameters. + if bodyField != "*" { + for _, field := range inputMessage.Fields { + fieldName := string(field.Desc.Name()) + log.Printf("bodyfield %v coveredParameters %#+v fieldName %v", bodyField, coveredParameters, fieldName) + if !contains(coveredParameters, fieldName) && fieldName != bodyField { + log.Printf("append!!! field %#+v", field) + fieldParams := g.buildQueryParamsV3(field) + log.Printf("add param %#+v field %#+v", fieldParams, field) + parameters = append(parameters, fieldParams...) + } else { + log.Printf("not append") } } } + // Create the response. responses := &v3.Responses{ ResponseOrReference: []*v3.NamedResponseOrReference{ - &v3.NamedResponseOrReference{ + { Name: "200", Value: &v3.ResponseOrReference{ Oneof: &v3.ResponseOrReference_Response{ @@ -325,20 +619,31 @@ func (g *openapiv3Generator) buildOperationV3( }, }, } + // Create the operation. op := &v3.Operation{ - Summary: description, + Tags: []string{tagName}, + Description: description, OperationId: operationID, Parameters: parameters, Responses: responses, } + + if defaultHost != "" { + hostURL, err := url.Parse(defaultHost) + if err == nil { + hostURL.Scheme = "https" + op.Servers = append(op.Servers, &v3.Server{Url: hostURL.String()}) + } + } + // If a body field is specified, we need to pass a message as the request body. if bodyField != "" { var bodyFieldScalarTypeName string var bodyFieldMessageTypeName string if bodyField == "*" { // Pass the entire request message as the request body. - bodyFieldMessageTypeName = g.fullMessageTypeName(inputMessage) + bodyFieldMessageTypeName = g.fullMessageTypeName(inputMessage.Desc) } else { // If body refers to a message field, use that type. for _, field := range inputMessage.Fields { @@ -347,7 +652,7 @@ func (g *openapiv3Generator) buildOperationV3( case protoreflect.StringKind: bodyFieldScalarTypeName = "string" case protoreflect.MessageKind: - bodyFieldMessageTypeName = g.fullMessageTypeName(field.Message) + bodyFieldMessageTypeName = g.fullMessageTypeName(field.Message.Desc) default: log.Printf("unsupported field type %+v", field.Desc) } @@ -365,12 +670,25 @@ func (g *openapiv3Generator) buildOperationV3( }, } } else if bodyFieldMessageTypeName != "" { - requestSchema = &v3.SchemaOrReference{ - Oneof: &v3.SchemaOrReference_Reference{ - Reference: &v3.Reference{ - XRef: g.schemaReferenceForTypeName(bodyFieldMessageTypeName), + switch bodyFieldMessageTypeName { + case ".google.protobuf.Empty": + fallthrough + case ".google.protobuf.Struct": + requestSchema = &v3.SchemaOrReference{ + Oneof: &v3.SchemaOrReference_Schema{ + Schema: &v3.Schema{ + Type: "object", + }, }, - }, + } + default: + requestSchema = &v3.SchemaOrReference{ + Oneof: &v3.SchemaOrReference_Reference{ + Reference: &v3.Reference{ + XRef: g.schemaReferenceForTypeName(bodyFieldMessageTypeName), + }, + }, + } } } op.RequestBody = &v3.RequestBodyOrReference{ @@ -379,7 +697,7 @@ func (g *openapiv3Generator) buildOperationV3( Required: true, Content: &v3.MediaTypes{ AdditionalProperties: []*v3.NamedMediaType{ - &v3.NamedMediaType{ + { Name: "application/json", Value: &v3.MediaType{ Schema: requestSchema, @@ -410,15 +728,15 @@ func (g *openapiv3Generator) addOperationV3(d *v3.Document, op *v3.Operation, pa } // Set the operation on the specified method. switch methodName { - case http.MethodGet: + case "GET": selectedPathItem.Value.Get = op - case http.MethodPost: + case "POST": selectedPathItem.Value.Post = op - case http.MethodPut: + case "PUT": selectedPathItem.Value.Put = op - case http.MethodDelete: + case "DELETE": selectedPathItem.Value.Delete = op - case http.MethodPatch: + case "PATCH": selectedPathItem.Value.Patch = op } } @@ -430,47 +748,29 @@ func (g *openapiv3Generator) schemaReferenceForTypeName(typeName string) string } parts := strings.Split(typeName, ".") lastPart := parts[len(parts)-1] - return "#/components/schemas/" + lastPart -} - -// itemsItemForTypeName is a helper constructor. -func (g *openapiv3Generator) itemsItemForTypeName(typeName string) *v3.ItemsItem { - return &v3.ItemsItem{SchemaOrReference: []*v3.SchemaOrReference{&v3.SchemaOrReference{ - Oneof: &v3.SchemaOrReference_Schema{ - Schema: &v3.Schema{ - Type: typeName, - }, - }, - }}} -} - -// itemsItemForReference is a helper constructor. -func (g *openapiv3Generator) itemsItemForReference(xref string) *v3.ItemsItem { - return &v3.ItemsItem{SchemaOrReference: []*v3.SchemaOrReference{&v3.SchemaOrReference{ - Oneof: &v3.SchemaOrReference_Reference{ - Reference: &v3.Reference{ - XRef: xref, - }, - }, - }}} + return "#/components/schemas/" + g.formatMessageRef(lastPart) } // fullMessageTypeName builds the full type name of a message. -func (g *openapiv3Generator) fullMessageTypeName(message *protogen.Message) string { - return "." + string(message.Desc.ParentFile().Package()) + "." + string(message.Desc.Name()) +func (g *openapiv3Generator) fullMessageTypeName(message protoreflect.MessageDescriptor) string { + name := g.getMessageName(message) + return "." + string(message.ParentFile().Package()) + "." + name } func (g *openapiv3Generator) responseContentForMessage(outputMessage *protogen.Message) *v3.MediaTypes { - typeName := g.fullMessageTypeName(outputMessage) + typeName := g.fullMessageTypeName(outputMessage.Desc) if typeName == ".google.protobuf.Empty" { return &v3.MediaTypes{} } + if typeName == ".google.protobuf.Struct" { + return &v3.MediaTypes{} + } - if typeName == ".google.api.HttpBody" || typeName == ".micro.codec.Frame" { + if typeName == ".google.api.HttpBody" || typeName == ".micro.codec.Frame" || typeName == ".micro.api.HttpBody" { return &v3.MediaTypes{ AdditionalProperties: []*v3.NamedMediaType{ - &v3.NamedMediaType{ + { Name: "application/octet-stream", Value: &v3.MediaType{}, }, @@ -480,13 +780,13 @@ func (g *openapiv3Generator) responseContentForMessage(outputMessage *protogen.M return &v3.MediaTypes{ AdditionalProperties: []*v3.NamedMediaType{ - &v3.NamedMediaType{ + { Name: "application/json", Value: &v3.MediaType{ Schema: &v3.SchemaOrReference{ Oneof: &v3.SchemaOrReference_Reference{ Reference: &v3.Reference{ - XRef: g.schemaReferenceForTypeName(g.fullMessageTypeName(outputMessage)), + XRef: g.schemaReferenceForTypeName(g.fullMessageTypeName(outputMessage.Desc)), }, }, }, @@ -496,11 +796,166 @@ func (g *openapiv3Generator) responseContentForMessage(outputMessage *protogen.M } } +func (g *openapiv3Generator) schemaOrReferenceForType(typeName string) *v3.SchemaOrReference { + switch typeName { + + case ".google.protobuf.Timestamp": + // Timestamps are serialized as strings + return &v3.SchemaOrReference{ + Oneof: &v3.SchemaOrReference_Schema{ + Schema: &v3.Schema{Type: "string", Format: "RFC3339"}, + }, + } + + case ".google.type.Date": + // Dates are serialized as strings + return &v3.SchemaOrReference{ + Oneof: &v3.SchemaOrReference_Schema{ + Schema: &v3.Schema{Type: "string", Format: "date"}, + }, + } + + case ".google.type.DateTime": + // DateTimes are serialized as strings + return &v3.SchemaOrReference{ + Oneof: &v3.SchemaOrReference_Schema{ + Schema: &v3.Schema{Type: "string", Format: "date-time"}, + }, + } + + case ".google.protobuf.FieldMask": + // Field masks are serialized as strings + return &v3.SchemaOrReference{ + Oneof: &v3.SchemaOrReference_Schema{ + Schema: &v3.Schema{Type: "string", Format: "field-mask"}, + }, + } + + case ".google.protobuf.Struct": + // Struct is equivalent to a JSON object + return &v3.SchemaOrReference{ + Oneof: &v3.SchemaOrReference_Schema{ + Schema: &v3.Schema{Type: "object"}, + }, + } + + case ".google.protobuf.Empty": + // Empty is close to JSON undefined than null, so ignore this field + return nil //&v3.SchemaOrReference{Oneof: &v3.SchemaOrReference_Schema{Schema: &v3.Schema{Type: "null"}}} + + default: + ref := g.schemaReferenceForTypeName(typeName) + return &v3.SchemaOrReference{ + Oneof: &v3.SchemaOrReference_Reference{ + Reference: &v3.Reference{XRef: ref}, + }, + } + } +} + +func (g *openapiv3Generator) schemaOrReferenceForField(field protoreflect.FieldDescriptor) *v3.SchemaOrReference { + if field.IsMap() { + return &v3.SchemaOrReference{ + Oneof: &v3.SchemaOrReference_Schema{ + Schema: &v3.Schema{ + Type: "object", + AdditionalProperties: &v3.AdditionalPropertiesItem{ + Oneof: &v3.AdditionalPropertiesItem_SchemaOrReference{ + SchemaOrReference: g.schemaOrReferenceForField(field.MapValue()), + }, + }, + }, + }, + } + } + + var kindSchema *v3.SchemaOrReference + + kind := field.Kind() + + switch kind { + + case protoreflect.MessageKind: + typeName := g.fullMessageTypeName(field.Message()) + kindSchema = g.schemaOrReferenceForType(typeName) + if kindSchema == nil { + return nil + } + + case protoreflect.StringKind: + kindSchema = &v3.SchemaOrReference{ + Oneof: &v3.SchemaOrReference_Schema{ + Schema: &v3.Schema{Type: "string"}, + }, + } + + case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Uint32Kind, + protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Uint64Kind, + protoreflect.Sfixed32Kind, protoreflect.Fixed32Kind, protoreflect.Sfixed64Kind, + protoreflect.Fixed64Kind: + kindSchema = &v3.SchemaOrReference{ + Oneof: &v3.SchemaOrReference_Schema{ + Schema: &v3.Schema{Type: "integer", Format: kind.String()}, + }, + } + + case protoreflect.EnumKind: + kindSchema = &v3.SchemaOrReference{ + Oneof: &v3.SchemaOrReference_Schema{ + Schema: &v3.Schema{Type: "integer", Format: "enum"}, + }, + } + + case protoreflect.BoolKind: + kindSchema = &v3.SchemaOrReference{ + Oneof: &v3.SchemaOrReference_Schema{ + Schema: &v3.Schema{Type: "boolean"}, + }, + } + + case protoreflect.FloatKind, protoreflect.DoubleKind: + kindSchema = &v3.SchemaOrReference{ + Oneof: &v3.SchemaOrReference_Schema{ + Schema: &v3.Schema{Type: "number", Format: kind.String()}, + }, + } + + case protoreflect.BytesKind: + kindSchema = &v3.SchemaOrReference{ + Oneof: &v3.SchemaOrReference_Schema{ + Schema: &v3.Schema{Type: "string", Format: "bytes"}, + }, + } + + default: + log.Printf("(TODO) Unsupported field type: %+v", g.fullMessageTypeName(field.Message())) + } + + if field.IsList() { + return &v3.SchemaOrReference{ + Oneof: &v3.SchemaOrReference_Schema{ + Schema: &v3.Schema{ + Type: "array", + Items: &v3.ItemsItem{SchemaOrReference: []*v3.SchemaOrReference{kindSchema}}, + }, + }, + } + } + + return kindSchema +} + // addSchemasToDocumentV3 adds info from one file descriptor. -func (g *openapiv3Generator) addSchemasToDocumentV3(d *v3.Document, file *protogen.File) { +func (g *openapiv3Generator) addSchemasToDocumentV3(d *v3.Document, messages []*protogen.Message) { // For each message, generate a definition. - for _, message := range file.Messages { - typeName := g.fullMessageTypeName(message) + for _, message := range messages { + + if message.Messages != nil { + g.addSchemasToDocumentV3(d, message.Messages) + } + + typeName := g.fullMessageTypeName(message.Desc) + // Only generate this if we need it and haven't already generated it. if !contains(g.requiredSchemas, typeName) || contains(g.generatedSchemas, typeName) { @@ -508,156 +963,70 @@ func (g *openapiv3Generator) addSchemasToDocumentV3(d *v3.Document, file *protog } g.generatedSchemas = append(g.generatedSchemas, typeName) // Get the message description from the comments. - messageDescription := g.filterCommentString(message.Comments.Leading) + messageDescription := g.filterCommentString(message.Comments.Leading, true) // Build an array holding the fields of the message. definitionProperties := &v3.Properties{ AdditionalProperties: make([]*v3.NamedSchemaOrReference, 0), } + var required []string for _, field := range message.Fields { - // Check the field annotations to see if this is a readonly field. + // Check the field annotations to see if this is a readonly or writeonly field. + inputOnly := false outputOnly := false - extension := proto.GetExtension(field.Desc.Options(), annotations.E_FieldBehavior) + extension := proto.GetExtension(field.Desc.Options(), api.E_FieldBehavior) if extension != nil { switch v := extension.(type) { - case []annotations.FieldBehavior: + case []api.FieldBehavior: for _, vv := range v { - if vv == annotations.FieldBehavior_OUTPUT_ONLY { + switch vv { + case api.FieldBehavior_OUTPUT_ONLY: outputOnly = true + case api.FieldBehavior_INPUT_ONLY: + inputOnly = true + case api.FieldBehavior_REQUIRED: + required = append(required, g.formatFieldName(field)) } } default: log.Printf("unsupported extension type %T", extension) } } - // Get the field description from the comments. - fieldDescription := g.filterCommentString(field.Comments.Leading) + // The field is either described by a reference or a schema. - XRef := "" - fieldSchema := &v3.Schema{ - Description: fieldDescription, + fieldSchema := g.schemaOrReferenceForField(field.Desc) + if fieldSchema == nil { + continue } - if outputOnly { - fieldSchema.ReadOnly = true - } - if field.Desc.IsList() { - fieldSchema.Type = "array" - switch field.Desc.Kind() { - case protoreflect.MessageKind: - fieldSchema.Items = g.itemsItemForReference( - g.schemaReferenceForTypeName( - g.fullMessageTypeName(field.Message))) - case protoreflect.StringKind: - fieldSchema.Items = g.itemsItemForTypeName("string") - case protoreflect.Int32Kind, - protoreflect.Sint32Kind, - protoreflect.Uint32Kind, - protoreflect.Int64Kind, - protoreflect.Sint64Kind, - protoreflect.Uint64Kind, - protoreflect.Sfixed32Kind, - protoreflect.Fixed32Kind, - protoreflect.Sfixed64Kind, - protoreflect.Fixed64Kind: - fieldSchema.Items = g.itemsItemForTypeName("integer") - case protoreflect.EnumKind: - fieldSchema.Items = g.itemsItemForTypeName("integer") - case protoreflect.BoolKind: - fieldSchema.Items = g.itemsItemForTypeName("boolean") - case protoreflect.FloatKind, protoreflect.DoubleKind: - fieldSchema.Items = g.itemsItemForTypeName("number") - case protoreflect.BytesKind: - fieldSchema.Items = g.itemsItemForTypeName("string") - default: - log.Printf("(TODO) Unsupported array type: %+v", g.fullMessageTypeName(field.Message)) + + if schema, ok := fieldSchema.Oneof.(*v3.SchemaOrReference_Schema); ok { + // Get the field description from the comments. + schema.Schema.Description = g.filterCommentString(field.Comments.Leading, true) + if outputOnly { + schema.Schema.ReadOnly = true } - } else if field.Desc.IsMap() && - field.Desc.MapKey().Kind() == protoreflect.StringKind && - field.Desc.MapValue().Kind() == protoreflect.StringKind { - fieldSchema.Type = "object" - } else { - k := field.Desc.Kind() - switch k { - case protoreflect.MessageKind: - typeName := g.fullMessageTypeName(field.Message) - switch typeName { - case ".google.protobuf.Timestamp": - // Timestamps are serialized as strings - fieldSchema.Type = "string" - fieldSchema.Format = "RFC3339" - case ".google.type.Date": - // Dates are serialized as strings - fieldSchema.Type = "string" - fieldSchema.Format = "date" - case ".google.type.DateTime": - // DateTimes are serialized as strings - fieldSchema.Type = "string" - fieldSchema.Format = "date-time" - default: - // The field is described by a reference. - XRef = g.schemaReferenceForTypeName(typeName) - } - case protoreflect.StringKind: - fieldSchema.Type = "string" - case protoreflect.Int32Kind, - protoreflect.Sint32Kind, - protoreflect.Uint32Kind, - protoreflect.Int64Kind, - protoreflect.Sint64Kind, - protoreflect.Uint64Kind, - protoreflect.Sfixed32Kind, - protoreflect.Fixed32Kind, - protoreflect.Sfixed64Kind, - protoreflect.Fixed64Kind: - fieldSchema.Type = "integer" - fieldSchema.Format = k.String() - case protoreflect.EnumKind: - fieldSchema.Type = "integer" - fieldSchema.Format = "enum" - case protoreflect.BoolKind: - fieldSchema.Type = "boolean" - case protoreflect.FloatKind, protoreflect.DoubleKind: - fieldSchema.Type = "number" - fieldSchema.Format = k.String() - case protoreflect.BytesKind: - fieldSchema.Type = "string" - fieldSchema.Format = "bytes" - default: - log.Printf("(TODO) Unsupported field type: %+v", g.fullMessageTypeName(field.Message)) - } - } - var value *v3.SchemaOrReference - if XRef != "" { - value = &v3.SchemaOrReference{ - Oneof: &v3.SchemaOrReference_Reference{ - Reference: &v3.Reference{ - XRef: XRef, - }, - }, - } - } else { - value = &v3.SchemaOrReference{ - Oneof: &v3.SchemaOrReference_Schema{ - Schema: fieldSchema, - }, + if inputOnly { + schema.Schema.WriteOnly = true } } + definitionProperties.AdditionalProperties = append( definitionProperties.AdditionalProperties, &v3.NamedSchemaOrReference{ - Name: string(field.Desc.Name()), - Value: value, + Name: g.formatFieldName(field), + Value: fieldSchema, }, ) } // Add the schema to the components.schema list. d.Components.Schemas.AdditionalProperties = append(d.Components.Schemas.AdditionalProperties, &v3.NamedSchemaOrReference{ - Name: string(message.Desc.Name()), + Name: g.formatMessageName(message), Value: &v3.SchemaOrReference{ Oneof: &v3.SchemaOrReference_Schema{ Schema: &v3.Schema{ Description: messageDescription, Properties: definitionProperties, + Required: required, }, }, }, @@ -676,6 +1045,14 @@ func contains(s []string, e string) bool { return false } +// appendUniuqe appends a string, to a string slice, if the string is not already in the slice +func appendUniuqe(s []string, e string) []string { + if !contains(s, e) { + return append(s, e) + } + return s +} + // singular produces the singular form of a collection name. func singular(plural string) string { if strings.HasSuffix(plural, "ves") { diff --git a/util.go b/util.go index dbaa3ab..9443274 100644 --- a/util.go +++ b/util.go @@ -62,9 +62,15 @@ func generateServiceClientMethods(gfile *protogen.GeneratedFile, service *protog if strings.HasPrefix(ref, "."+string(service.Desc.ParentFile().Package())+".") { ref = strings.TrimPrefix(ref, "."+string(service.Desc.ParentFile().Package())+".") } - if ref == "micro.codec.Frame" || ref == ".micro.codec.Frame" { + if ref[0] == '.' { + ref = ref[1:] + } + switch ref { + case "micro.codec.Frame": gfile.P(`errmap["`, rsp.Name, `"] = &`, microCodecPackage.Ident("Frame"), "{}") - } else { + case "micro.errors.Error": + gfile.P(`errmap["`, rsp.Name, `"] = &`, microErrorsPackage.Ident("Error"), "{}") + default: gfile.P(`errmap["`, rsp.Name, `"] = &`, ref, "{}") } } @@ -165,7 +171,7 @@ func generateServiceClientMethods(gfile *protogen.GeneratedFile, service *protog } if method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { - gfile.P("func (s *", unexport(serviceName), "Client", method.GoName, ") RecvAndClose() (*", gfile.QualifiedGoIdent(method.Output.GoIdent), ", error) {") + gfile.P("func (s *", unexport(serviceName), "Client", method.GoName, ") CloseAndRecv() (*", gfile.QualifiedGoIdent(method.Output.GoIdent), ", error) {") gfile.P("msg := &", gfile.QualifiedGoIdent(method.Output.GoIdent), "{}") gfile.P("err := s.RecvMsg(msg)") gfile.P("if err == nil {") @@ -183,6 +189,10 @@ func generateServiceClientMethods(gfile *protogen.GeneratedFile, service *protog gfile.P("return s.stream.Close()") gfile.P("}") gfile.P() + gfile.P("func (s *", unexport(serviceName), "Client", method.GoName, ") CloseSend() error {") + gfile.P("return s.stream.CloseSend()") + gfile.P("}") + gfile.P() gfile.P("func (s *", unexport(serviceName), "Client", method.GoName, ") Context() ", contextPackage.Ident("Context"), " {") gfile.P("return s.stream.Context()") gfile.P("}") @@ -478,7 +488,8 @@ func generateServiceClientStreamInterface(gfile *protogen.GeneratedFile, service gfile.P("SendMsg(msg interface{}) error") gfile.P("RecvMsg(msg interface{}) error") if method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { - gfile.P("RecvAndClose() (*", gfile.QualifiedGoIdent(method.Output.GoIdent), ", error)") + gfile.P("CloseAndRecv() (*", gfile.QualifiedGoIdent(method.Output.GoIdent), ", error)") + gfile.P("CloseSend() () error") } gfile.P("Close() error") if method.Desc.IsStreamingClient() { @@ -505,6 +516,7 @@ func generateServiceServerStreamInterface(gfile *protogen.GeneratedFile, service gfile.P("RecvMsg(msg interface{}) error") if method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { gfile.P("SendAndClose(msg *", gfile.QualifiedGoIdent(method.Output.GoIdent), ") error") + gfile.P("CloseSend() error") } gfile.P("Close() error") if method.Desc.IsStreamingClient() { diff --git a/variables.go b/variables.go index c7e7397..0cf425b 100644 --- a/variables.go +++ b/variables.go @@ -17,6 +17,7 @@ var ( microClientHttpPackage = protogen.GoImportPath("go.unistack.org/micro-client-http/v3") microServerHttpPackage = protogen.GoImportPath("go.unistack.org/micro-server-http/v3") microCodecPackage = protogen.GoImportPath("go.unistack.org/micro/v3/codec") + microErrorsPackage = protogen.GoImportPath("go.unistack.org/micro/v3/errors") timePackage = protogen.GoImportPath("time") deprecationComment = "// Deprecated: Do not use." versionComment = "v3.5.3"