Skip to content

Commit

Permalink
Merge pull request #1856 from josephschorr/lsp-improvements
Browse files Browse the repository at this point in the history
LSP improvements
  • Loading branch information
josephschorr committed Apr 5, 2024
2 parents d73b0ac + adb641e commit d3150c2
Show file tree
Hide file tree
Showing 18 changed files with 534 additions and 131 deletions.
127 changes: 106 additions & 21 deletions internal/lsp/handlers.go
Expand Up @@ -13,7 +13,9 @@ import (
log "github.com/authzed/spicedb/internal/logging"
"github.com/authzed/spicedb/pkg/development"
developerv1 "github.com/authzed/spicedb/pkg/proto/developer/v1"
"github.com/authzed/spicedb/pkg/schemadsl/compiler"
"github.com/authzed/spicedb/pkg/schemadsl/generator"
"github.com/authzed/spicedb/pkg/schemadsl/input"
)

func (s *Server) textDocDiagnostic(ctx context.Context, r *jsonrpc2.Request) (FullDocumentDiagnosticReport, error) {
Expand Down Expand Up @@ -45,7 +47,7 @@ func (s *Server) textDocDiagnostic(ctx context.Context, r *jsonrpc2.Request) (Fu

func (s *Server) computeDiagnostics(ctx context.Context, uri lsp.DocumentURI) ([]lsp.Diagnostic, error) {
diagnostics := make([]lsp.Diagnostic, 0) // Important: must not be nil for the consumer on the client side
if err := s.withFiles(func(files *persistent.Map[lsp.DocumentURI, string]) error {
if err := s.withFiles(func(files *persistent.Map[lsp.DocumentURI, trackedFile]) error {
file, ok := files.Get(uri)
if !ok {
log.Warn().
Expand All @@ -56,7 +58,7 @@ func (s *Server) computeDiagnostics(ctx context.Context, uri lsp.DocumentURI) ([
}

_, devErrs, err := development.NewDevContext(ctx, &developerv1.RequestContext{
Schema: file,
Schema: file.contents,
Relationships: nil,
})
if err != nil {
Expand Down Expand Up @@ -88,7 +90,7 @@ func (s *Server) textDocDidChange(ctx context.Context, r *jsonrpc2.Request, conn
return nil, err
}

s.files.Set(params.TextDocument.URI, params.ContentChanges[0].Text, nil)
s.files.Set(params.TextDocument.URI, trackedFile{params.ContentChanges[0].Text, nil}, nil)

if err := s.publishDiagnosticsIfNecessary(ctx, conn, params.TextDocument.URI); err != nil {
return nil, err
Expand All @@ -115,7 +117,7 @@ func (s *Server) textDocDidOpen(ctx context.Context, r *jsonrpc2.Request, conn *

uri := params.TextDocument.URI
contents := params.TextDocument.Text
s.files.Set(uri, contents, nil)
s.files.Set(uri, trackedFile{contents, nil}, nil)

if err := s.publishDiagnosticsIfNecessary(ctx, conn, uri); err != nil {
return nil, err
Expand Down Expand Up @@ -150,36 +152,113 @@ func (s *Server) publishDiagnosticsIfNecessary(ctx context.Context, conn *jsonrp
})
}

func (s *Server) textDocFormat(ctx context.Context, r *jsonrpc2.Request) ([]lsp.TextEdit, error) {
params, err := unmarshalParams[lsp.DocumentFormattingParams](r)
func (s *Server) getCompiledContents(path lsp.DocumentURI, files *persistent.Map[lsp.DocumentURI, trackedFile]) (*compiler.CompiledSchema, error) {
file, ok := files.Get(path)
if !ok {
return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInternalError, Message: "file not found"}
}

compiled := file.parsed
if compiled != nil {
return compiled, nil
}

justCompiled, derr, err := development.CompileSchema(file.contents)
if err != nil || derr != nil {
return nil, err
}

files.Set(path, trackedFile{file.contents, justCompiled}, nil)
return justCompiled, nil
}

func (s *Server) textDocHover(_ context.Context, r *jsonrpc2.Request) (*Hover, error) {
params, err := unmarshalParams[lsp.TextDocumentPositionParams](r)
if err != nil {
return nil, err
}

var formatted string
err = s.withFiles(func(files *persistent.Map[lsp.DocumentURI, string]) error {
file, ok := files.Get(params.TextDocument.URI)
if !ok {
log.Warn().
Str("uri", string(params.TextDocument.URI)).
Msg("file not found for formatting")
var hoverContents *Hover
err = s.withFiles(func(files *persistent.Map[lsp.DocumentURI, trackedFile]) error {
compiled, err := s.getCompiledContents(params.TextDocument.URI, files)
if err != nil {
return err
}

return &jsonrpc2.Error{Code: jsonrpc2.CodeInternalError, Message: "file not found"}
resolver, err := development.NewResolver(compiled)
if err != nil {
return err
}

dctx, devErrs, err := development.NewDevContext(ctx, &developerv1.RequestContext{
Schema: file,
Relationships: nil,
})
position := input.Position{
LineNumber: params.Position.Line,
ColumnPosition: params.Position.Character,
}

resolved, err := resolver.ReferenceAtPosition(input.Source("schema"), position)
if err != nil {
return err
}

if len(devErrs.GetInputErrors()) > 0 {
if resolved == nil {
return nil
}

formattedSchema, _, err := generator.GenerateSchema(dctx.CompiledSchema.OrderedDefinitions)
var lspRange *lsp.Range
if resolved.TargetPosition != nil {
lspRange = &lsp.Range{
Start: lsp.Position{
Line: resolved.TargetPosition.LineNumber,
Character: resolved.TargetPosition.ColumnPosition + resolved.TargetNamePositionOffset,
},
End: lsp.Position{
Line: resolved.TargetPosition.LineNumber,
Character: resolved.TargetPosition.ColumnPosition + resolved.TargetNamePositionOffset + len(resolved.Text),
},
}
}

if resolved.TargetSourceCode != "" {
hoverContents = &Hover{
Contents: MarkupContent{
Language: "spicedb",
Value: resolved.TargetSourceCode,
},
Range: lspRange,
}
} else {
hoverContents = &Hover{
Contents: MarkupContent{
Kind: "markdown",
Value: resolved.ReferenceMarkdown,
},
Range: lspRange,
}
}

return nil
})
if err != nil {
return nil, err
}

return hoverContents, nil
}

func (s *Server) textDocFormat(_ context.Context, r *jsonrpc2.Request) ([]lsp.TextEdit, error) {
params, err := unmarshalParams[lsp.DocumentFormattingParams](r)
if err != nil {
return nil, err
}

var formatted string
err = s.withFiles(func(files *persistent.Map[lsp.DocumentURI, trackedFile]) error {
compiled, err := s.getCompiledContents(params.TextDocument.URI, files)
if err != nil {
return err
}

formattedSchema, _, err := generator.GenerateSchema(compiled.OrderedDefinitions)
if err != nil {
return err
}
Expand Down Expand Up @@ -236,6 +315,7 @@ func (s *Server) initialize(_ context.Context, r *jsonrpc2.Request) (any, error)
CompletionProvider: &lsp.CompletionOptions{TriggerCharacters: []string{"."}},
DocumentFormattingProvider: true,
DiagnosticProvider: &DiagnosticOptions{Identifier: "spicedb", InterFileDependencies: false, WorkspaceDiagnostics: false},
HoverProvider: true,
},
}, nil
}
Expand All @@ -247,7 +327,12 @@ func (s *Server) shutdown() error {
return nil
}

func (s *Server) withFiles(fn func(*persistent.Map[lsp.DocumentURI, string]) error) error {
type trackedFile struct {
contents string
parsed *compiler.CompiledSchema
}

func (s *Server) withFiles(fn func(*persistent.Map[lsp.DocumentURI, trackedFile]) error) error {
clone := s.files.Clone()
defer clone.Destroy()
return fn(clone)
Expand Down
6 changes: 4 additions & 2 deletions internal/lsp/lsp.go
Expand Up @@ -25,7 +25,7 @@ const (

// Server is a Language Server Protocol server for SpiceDB schema development.
type Server struct {
files *persistent.Map[lsp.DocumentURI, string]
files *persistent.Map[lsp.DocumentURI, trackedFile]
state serverState

requestsDiagnostics bool
Expand All @@ -35,7 +35,7 @@ type Server struct {
func NewServer() *Server {
return &Server{
state: serverStateNotInitialized,
files: persistent.NewMap[lsp.DocumentURI, string](func(x, y lsp.DocumentURI) bool {
files: persistent.NewMap[lsp.DocumentURI, trackedFile](func(x, y lsp.DocumentURI) bool {
return string(x) < string(y)
}),
}
Expand Down Expand Up @@ -86,6 +86,8 @@ func (s *Server) handle(ctx context.Context, conn *jsonrpc2.Conn, r *jsonrpc2.Re
result, err = s.textDocDiagnostic(ctx, r)
case "textDocument/formatting":
result, err = s.textDocFormat(ctx, r)
case "textDocument/hover":
result, err = s.textDocHover(ctx, r)
default:
log.Ctx(ctx).Warn().
Str("method", r.Method).
Expand Down
35 changes: 32 additions & 3 deletions internal/lsp/lsp_test.go
Expand Up @@ -28,13 +28,13 @@ func TestDocumentChange(t *testing.T) {

contents, ok := tester.server.files.Get("file:///test")
require.True(t, ok)
require.Equal(t, "test", contents)
require.Equal(t, "test", contents.contents)

tester.setFileContents("file:///test", "test2")

contents, ok = tester.server.files.Get("file:///test")
require.True(t, ok)
require.Equal(t, "test2", contents)
require.Equal(t, "test2", contents.contents)
}

func TestDocumentNoDiagnostics(t *testing.T) {
Expand Down Expand Up @@ -138,7 +138,7 @@ func TestDocumentOpenedClosed(t *testing.T) {

contents, ok := tester.server.files.Get(lsp.DocumentURI("file:///test"))
require.True(t, ok)
require.Equal(t, "definition user{}", contents)
require.Equal(t, "definition user{}", contents.contents)

sendAndReceive[any](tester, "textDocument/didClose", lsp.DidCloseTextDocumentParams{
TextDocument: lsp.TextDocumentIdentifier{
Expand All @@ -149,3 +149,32 @@ func TestDocumentOpenedClosed(t *testing.T) {
_, ok = tester.server.files.Get(lsp.DocumentURI("file:///test"))
require.False(t, ok)
}

func TestDocumentHover(t *testing.T) {
tester := newLSPTester(t)
tester.initialize()

sendAndReceive[any](tester, "textDocument/didOpen", lsp.DidOpenTextDocumentParams{
TextDocument: lsp.TextDocumentItem{
URI: lsp.DocumentURI("file:///test"),
LanguageID: "test",
Version: 1,
Text: `definition user {}
definition resource {
relation viewer: user
}
`,
},
})

resp, _ := sendAndReceive[Hover](tester, "textDocument/hover", lsp.TextDocumentPositionParams{
TextDocument: lsp.TextDocumentIdentifier{
URI: lsp.DocumentURI("file:///test"),
},
Position: lsp.Position{Line: 3, Character: 18},
})

require.Equal(t, "definition user {}", resp.Contents.Value)
require.Equal(t, "spicedb", resp.Contents.Language)
}
16 changes: 15 additions & 1 deletion internal/lsp/lspdefs.go
@@ -1,6 +1,8 @@
package lsp

import baselsp "github.com/sourcegraph/go-lsp"
import (
baselsp "github.com/sourcegraph/go-lsp"
)

type InitializeResult struct {
Capabilities ServerCapabilities `json:"capabilities,omitempty"`
Expand All @@ -11,6 +13,7 @@ type ServerCapabilities struct {
CompletionProvider *baselsp.CompletionOptions `json:"completionProvider,omitempty"`
DocumentFormattingProvider bool `json:"documentFormattingProvider,omitempty"`
DiagnosticProvider *DiagnosticOptions `json:"diagnosticProvider,omitempty"`
HoverProvider bool `json:"hoverProvider,omitempty"`
}

type DiagnosticOptions struct {
Expand Down Expand Up @@ -58,3 +61,14 @@ type DiagnosticWorkspaceClientCapabilities struct {
// `textDocument/diagnostic` request.
RefreshSupport bool `json:"refreshSupport,omitempty"`
}

type Hover struct {
Contents MarkupContent `json:"contents"`
Range *baselsp.Range `json:"range,omitempty"`
}

type MarkupContent struct {
Kind string `json:"kind,omitempty"`
Language string `json:"language,omitempty"`
Value string `json:"value"`
}
15 changes: 15 additions & 0 deletions pkg/cmd/lsp.go
Expand Up @@ -4,10 +4,16 @@ import (
"context"
"time"

"github.com/go-logr/zerologr"
"github.com/jzelinskie/cobrautil/v2"
"github.com/jzelinskie/cobrautil/v2/cobrazerolog"
"github.com/rs/zerolog"
"github.com/spf13/cobra"

"github.com/authzed/spicedb/internal/logging"
"github.com/authzed/spicedb/internal/lsp"
"github.com/authzed/spicedb/pkg/cmd/termination"
"github.com/authzed/spicedb/pkg/releases"
)

// LSPConfig is the configuration for the LSP command.
Expand All @@ -33,6 +39,15 @@ func NewLSPCommand(programName string, config *LSPConfig) *cobra.Command {
return &cobra.Command{
Use: "lsp",
Short: "serve language server protocol",
PreRunE: cobrautil.CommandStack(
cobrautil.SyncViperDotEnvPreRunE(programName, "spicedb.env", zerologr.New(&logging.Logger)),
cobrazerolog.New(
cobrazerolog.WithTarget(func(logger zerolog.Logger) {
logging.SetGlobalLogger(logger)
}),
).RunE(),
releases.CheckAndLogRunE(),
),
RunE: termination.PublishError(func(cmd *cobra.Command, args []string) error {
srv, err := config.Complete(cmd.Context())
if err != nil {
Expand Down

0 comments on commit d3150c2

Please sign in to comment.