oauth2认证与拦截器
类似java spring中的拦截器。gRpc也有拦截器的说法,拦截器可作用于客户端请求,服务端请求。对请求进行拦截,进行业务上的一些封装校验等,类似一个中间件的作用
拦截器类型
- 一元请求拦截器
- 流式请求拦截器
- 链式拦截器(一个个调用对应处理函数)
使用场景:
拦截器可以从元数据获取一些认证进行进行校验。
服务端拦截器
拦截器定义
interceptor.go
package serverimport ("context""errors""fmt""google.golang.org/grpc""google.golang.org/grpc/metadata""strings"
)// UnaryInterceptor 一元请求拦截器
func UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {fmt.Println("server UnaryInterceptor:", info)if err := oauth2Valid(ctx); err != nil {return nil, err}return handler(ctx, req)
}// StreamInterceptor 流式拦截器
func StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {fmt.Println("server StreamInterceptor")fmt.Println(info)if err := oauth2Valid(ss.Context()); err != nil {return err}return handler(srv, ss)
}// oauth2认证,从上下文获取请求元数据
func oauth2Valid(ctx context.Context) error {md, ok := metadata.FromIncomingContext(ctx)if !ok {return errors.New("元数据获取失败, 身份认证失败")}authorization := md["authorization"]if !valid(authorization) {return errors.New("令牌校验不通过, 身份认证失败")}return nil
}func valid(authorization []string) bool {if len(authorization) < 1 {return false}token := strings.TrimPrefix(authorization[0], "Bearer ")return token == fetchToken()
}func fetchToken() string {return "some-secret-token"
}
拦截器配置
package mainimport ("flag""fmt""google.golang.org/grpc""grpc/echo""grpc/echo-server-practice/server""log""net"
)var (port = flag.Int("port", 50053, "port")
)func getOptions() (opts []grpc.ServerOption) {opts = make([]grpc.ServerOption, 0)opts = append(opts, server.GetMTlsOpt())// 附加一个拦截器,还有链式拦截器 ChainInterceptoropts = append(opts, grpc.UnaryInterceptor(server.UnaryInterceptor))opts = append(opts, grpc.StreamInterceptor(server.StreamInterceptor))return opts
}func main() {flag.Parse()lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port))if err != nil {log.Fatal(err)}// grpc servers := grpc.NewServer(getOptions()...)......
}
客户端拦截器
客户端拦截器中简单实现利用oauth2做认证
package clientimport ("fmt""golang.org/x/net/context""google.golang.org/grpc""grpc/echo-client/client"
)// UnaryInterceptor 客户端一元请求拦截器func UnaryInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {fmt.Println("server UnaryInterceptor: ", req)// 其实就proto生成的一元请求里的invoke差不多return invoker(ctx, method, req, reply, cc, opts...)
}
// StreamInterceptor 客户端流式拦截器
func StreamInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {fmt.Println("client streamInterceptor")// 和流式请求里的NewStream也是一样的return streamer(ctx, desc, cc, method, opts...)
}
如果客户端启动没配置,可以在拦截器中添加。
客户端 main.go
package mainimport ("flag""google.golang.org/grpc""grpc/echo""grpc/echo-client-practice/client""log"
)var (addr = flag.String("host", "localhost:50053", "")
)func getDiaOption() []grpc.DialOption {dialOptions := make([]grpc.DialOption, 0)dialOptions = append(dialOptions, client.GetMTlsOpt())dialOptions = append(dialOptions, grpc.WithUnaryInterceptor(client.UnaryInterceptor))dialOptions = append(dialOptions, grpc.WithStreamInterceptor(client.StreamInterceptor))dialOptions = append(dialOptions, client.GetAuth(client.FetchToken()))return dialOptions
}func main() {flag.Parse()conn, err := grpc.Dial(*addr, getDiaOption()...)if err != nil {log.Fatal(err)}defer conn.Close()# 下面是伪代码c := echo.NewYourClient(conn)c.CallYourRpc(your_request)
}
链式拦截器执行核心原理
func ChainUnaryClient(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor {n := len(interceptors)if n > 1 {lastI := n - 1return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {var (chainHandler grpc.UnaryInvokercurI int)chainHandler = func(currentCtx context.Context, currentMethod string, currentReq, currentRepl interface{}, currentConn *grpc.ClientConn, currentOpts ...grpc.CallOption) error {if curI == lastI {return invoker(currentCtx, currentMethod, currentReq, currentRepl, currentConn, currentOpts...)}curI++err := interceptors[curI](currentCtx, currentMethod, currentReq, currentRepl, currentConn, chainHandler, currentOpts...)curI--return err}return interceptors[0](ctx, method, req, reply, cc, chainHandler, opts...)}}...
}
当拦截器数量大于 1 时,从 interceptors[1]
开始递归,每一个递归的拦截器 interceptors[i]
会不断地执行,最后才真正的去执行 handler
方法。