From cae06c5c618fb1bf7c3739044ed91df2d6391efc Mon Sep 17 00:00:00 2001 From: Philipp Heckel Date: Wed, 13 Jul 2022 20:31:17 -0400 Subject: [PATCH] Continued --- client/options.go | 4 ++++ cmd/publish.go | 32 ++++++++++++++++++++++++++++++-- crypto/crypto.go | 16 ++++++++-------- crypto/crypto_test.go | 15 +++++++++++---- server/server.go | 8 ++++++++ server/types.go | 8 +++++++- util/peek.go | 10 ++++++++++ 7 files changed, 78 insertions(+), 15 deletions(-) diff --git a/client/options.go b/client/options.go index 7d599699..26d9d56a 100644 --- a/client/options.go +++ b/client/options.go @@ -92,6 +92,10 @@ func WithNoFirebase() PublishOption { return WithHeader("X-Firebase", "no") } +func WithEncrypted() PublishOption { + return WithHeader("X-Encryption", "jwe") +} + // WithSince limits the number of messages returned from the server. The parameter since can be a Unix // timestamp (see WithSinceUnixTime), a duration (WithSinceDuration) the word "all" (see WithSinceAll). func WithSince(since string) SubscribeOption { diff --git a/cmd/publish.go b/cmd/publish.go index f07f783d..3cc2085a 100644 --- a/cmd/publish.go +++ b/cmd/publish.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/urfave/cli/v2" "heckel.io/ntfy/client" + "heckel.io/ntfy/crypto" "heckel.io/ntfy/log" "heckel.io/ntfy/util" "io" @@ -15,6 +16,10 @@ import ( "time" ) +const ( + encryptedMessageBytesLimit = 100 * 1024 * 1024 // 100 MB +) + func init() { commands = append(commands, cmdPublish) } @@ -100,7 +105,7 @@ func execPublish(c *cli.Context) error { noFirebase := c.Bool("no-firebase") quiet := c.Bool("quiet") pid := c.Int("wait-pid") - //password := os.Getenv("NTFY_PASSWORD") + password := os.Getenv("NTFY_PASSWORD") topic, message, command, err := parseTopicMessageCommand(c) if err != nil { return err @@ -193,6 +198,20 @@ func execPublish(c *cli.Context) error { } } } + if password != "" { + topicURL := expandTopicURL(topic, conf.DefaultHost) + key := crypto.DeriveKey(password, topicURL) + peaked, err := util.PeekLimit(io.NopCloser(body), encryptedMessageBytesLimit) + if err != nil { + return err + } + ciphertext, err := crypto.Encrypt(peaked.PeekedBytes, key) + if err != nil { + return err + } + body = strings.NewReader(ciphertext) + options = append(options, client.WithEncrypted()) + } cl := client.New(conf) m, err := cl.PublishReader(topic, body, options...) if err != nil { @@ -204,8 +223,17 @@ func execPublish(c *cli.Context) error { return nil } -// parseTopicMessageCommand reads the topic and the remaining arguments from the context. +func expandTopicURL(topic, defaultHost string) string { + if strings.HasPrefix(topic, "http://") || strings.HasPrefix(topic, "https://") { + return topic + } else if strings.Contains(topic, "/") { + return fmt.Sprintf("https://%s", topic) + } + return fmt.Sprintf("%s/%s", defaultHost, topic) +} +// parseTopicMessageCommand reads the topic and the remaining arguments from the context. +// // There are a few cases to consider: // ntfy publish [] // ntfy publish --wait-cmd diff --git a/crypto/crypto.go b/crypto/crypto.go index 4ddcc2da..c07af19f 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -13,31 +13,31 @@ const ( keyDerivIter = 50000 ) -func DeriveKey(password string, topicURL string) []byte { +func DeriveKey(password, topicURL string) []byte { salt := sha256.Sum256([]byte(topicURL)) return pbkdf2.Key([]byte(password), salt[:], keyDerivIter, keyLenBytes, sha256.New) } -func Encrypt(plaintext string, key []byte) (string, error) { +func Encrypt(plaintext []byte, key []byte) (string, error) { enc, err := jose.NewEncrypter(jweEncryption, jose.Recipient{Algorithm: jweAlgorithm, Key: key}, nil) if err != nil { return "", err } - jwe, err := enc.Encrypt([]byte(plaintext)) + jwe, err := enc.Encrypt(plaintext) if err != nil { return "", err } return jwe.CompactSerialize() } -func Decrypt(input string, key []byte) (string, error) { - jwe, err := jose.ParseEncrypted(input) +func Decrypt(ciphertext string, key []byte) ([]byte, error) { + jwe, err := jose.ParseEncrypted(ciphertext) if err != nil { - return "", err + return nil, err } out, err := jwe.Decrypt(key) if err != nil { - return "", err + return nil, err } - return string(out), nil + return out, nil } diff --git a/crypto/crypto_test.go b/crypto/crypto_test.go index 79a51756..70fffde6 100644 --- a/crypto/crypto_test.go +++ b/crypto/crypto_test.go @@ -1,25 +1,32 @@ package crypto import ( + "fmt" "github.com/stretchr/testify/require" "testing" ) +func TestDeriveKey(t *testing.T) { + key := DeriveKey("secr3t password", "https://ntfy.sh/mysecret") + require.Equal(t, "30b7e72f6273da6e59d2dec535466e548da3eafc98650c9664c06edab707fa25", fmt.Sprintf("%x", key)) +} + func TestEncryptDecrypt(t *testing.T) { message := "this is a message or is it?" - ciphertext, err := Encrypt(message, []byte("AES256Key-32Characters1234567890")) + ciphertext, err := Encrypt([]byte(message), []byte("AES256Key-32Characters1234567890")) require.Nil(t, err) plaintext, err := Decrypt(ciphertext, []byte("AES256Key-32Characters1234567890")) require.Nil(t, err) - require.Equal(t, message, plaintext) + require.Equal(t, message, string(plaintext)) } func TestEncryptDecrypt_FromPHP(t *testing.T) { ciphertext := "eyJhbGciOiJkaXIiLCJlbmMiOiJBMjU2R0NNIn0..vbe1Qv_-mKYbUgce.EfmOUIUi7lxXZG_o4bqXZ9pmpr1Rzs4Y5QLE2XD2_aw_SQ.y2hadrN5b2LEw7_PJHhbcA" key := DeriveKey("secr3t password", "https://ntfy.sh/mysecret") + fmt.Printf("%x", key) plaintext, err := Decrypt(ciphertext, key) require.Nil(t, err) - require.Equal(t, `{"message":"Secret!","priority":5}`, plaintext) + require.Equal(t, `{"message":"Secret!","priority":5}`, string(plaintext)) } func TestEncryptDecrypt_FromPython(t *testing.T) { @@ -27,5 +34,5 @@ func TestEncryptDecrypt_FromPython(t *testing.T) { key := DeriveKey("secr3t password", "https://ntfy.sh/mysecret") plaintext, err := Decrypt(ciphertext, key) require.Nil(t, err) - require.Equal(t, `{"message":"Python says hi","tags":["secret"]}`, plaintext) + require.Equal(t, `{"message":"Python says hi","tags":["secret"]}`, string(plaintext)) } diff --git a/server/server.go b/server/server.go index 94f35801..74a31c4d 100644 --- a/server/server.go +++ b/server/server.go @@ -95,6 +95,7 @@ const ( newMessageBody = "New message" // Used in poll requests as generic message defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment encodingBase64 = "base64" + encodingJWE = "jwe" ) // WebSocket constants @@ -461,6 +462,9 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes if m.PollID != "" { m = newPollRequestMessage(t.ID, m.PollID) } + if m.Encoding == encodingJWE { + m = newEncryptedMessage(t.ID, m.Message) + } if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil { return nil, err } @@ -644,6 +648,10 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca return false, false, "", false, wrapErrHTTP(errHTTPBadRequestActionsInvalid, err.Error()) } } + encryption := readParam(r, "x-encryption", "encryption", "encrypted", "encrypt", "enc") + if encryption == "yes" || encryption == "true" || encryption == "1" || encryption == encodingJWE { + m.Encoding = encodingJWE + } unifiedpush = readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see GET too! if unifiedpush { firebase = false diff --git a/server/types.go b/server/types.go index 44fe9e9e..d59e748b 100644 --- a/server/types.go +++ b/server/types.go @@ -33,7 +33,7 @@ type message struct { Attachment *attachment `json:"attachment,omitempty"` PollID string `json:"poll_id,omitempty"` Sender string `json:"-"` // IP address of uploader, used for rate limiting - Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes + Encoding string `json:"encoding,omitempty"` // empty for UTF-8, "base64", or "jwe" (encrypted) } type attachment struct { @@ -115,6 +115,12 @@ func newPollRequestMessage(topic, pollID string) *message { return m } +func newEncryptedMessage(topic, msg string) *message { + m := newMessage(messageEvent, topic, msg) + m.Encoding = encodingJWE + return m +} + func validMessageID(s string) bool { return util.ValidRandomString(s, messageIDLength) } diff --git a/util/peek.go b/util/peek.go index 40150cbc..499d5069 100644 --- a/util/peek.go +++ b/util/peek.go @@ -38,6 +38,16 @@ func Peek(underlying io.ReadCloser, limit int) (*PeekedReadCloser, error) { }, nil } +func PeekLimit(underlying io.ReadCloser, limit int) (*PeekedReadCloser, error) { + rc, err := Peek(underlying, limit) + if err != nil { + return nil, err + } else if rc.LimitReached { + return nil, ErrLimitReached + } + return rc, nil +} + // Read reads from the peeked bytes and then from the underlying stream func (r *PeekedReadCloser) Read(p []byte) (n int, err error) { if r.closed {