diff --git a/src/_utils.js b/src/_utils.js
index ea8d1d66..4dcaddf4 100644
--- a/src/_utils.js
+++ b/src/_utils.js
@@ -21,7 +21,7 @@ const joinSpreads = spreads => spreads.reduce((acc, curr) => or(acc, curr))
export const hashString = str => String(_hashString(str))
-export const addClassName = (path, jsxId) => {
+export const addClassName = (path, jsxId, classNameAttribute = 'className') => {
const jsxIdWithSpace = concat(jsxId, t.stringLiteral(' '))
const attributes = path.get('attributes')
const spreads = []
@@ -35,7 +35,7 @@ export const addClassName = (path, jsxId) => {
const properties = node.argument.properties
const index = properties.findIndex(
- property => property.key.name === 'className'
+ property => property.key.name === classNameAttribute
)
if (~index) {
@@ -59,7 +59,7 @@ export const addClassName = (path, jsxId) => {
: t.identifier(name)
const attrNameDotClassName = t.memberExpression(
spreadObj,
- t.identifier('className')
+ t.identifier(classNameAttribute)
)
spreads.push(
@@ -76,7 +76,7 @@ export const addClassName = (path, jsxId) => {
continue
}
- if (t.isJSXAttribute(attr) && node.name.name === 'className') {
+ if (t.isJSXAttribute(attr) && node.name.name === classNameAttribute) {
className = attributes[i]
// found className break the loop
break
@@ -105,7 +105,7 @@ export const addClassName = (path, jsxId) => {
}
path.node.attributes.push(
- t.jSXAttribute(t.jSXIdentifier('className'), className)
+ t.jSXAttribute(t.jSXIdentifier(classNameAttribute), className)
)
}
diff --git a/src/babel.js b/src/babel.js
index 8c01155d..aea23751 100644
--- a/src/babel.js
+++ b/src/babel.js
@@ -19,6 +19,11 @@ import {
import { STYLE_COMPONENT } from './_constants'
+const getClassNameAttributeNameForElementName = name => {
+ const isLowerCase = name.charAt(0) === name.charAt(0).toLowerCase()
+ return isLowerCase && name.includes('-') ? 'class' : 'className'
+}
+
export default function({ types: t }) {
const jsxVisitors = {
JSXOpeningElement(path, state) {
@@ -49,8 +54,9 @@ export default function({ types: t }) {
binding.referencePaths.some(r => r === tag)
))
) {
+ const classNameAttribute = getClassNameAttributeNameForElementName(name)
if (state.className) {
- addClassName(path, state.className)
+ addClassName(path, state.className, classNameAttribute)
}
}
diff --git a/test/attribute.js b/test/attribute.js
index 441a4322..d9cd23ee 100644
--- a/test/attribute.js
+++ b/test/attribute.js
@@ -18,6 +18,13 @@ test('rewrites className', async t => {
t.snapshot(code)
})
+test('rewrites class for custom components', async t => {
+ const { code } = await transform(
+ './fixtures/attribute-generation-custom-component-class-rewriting.js'
+ )
+ t.snapshot(code)
+})
+
test('generate attribute for mixed modes (global, static, dynamic)', async t => {
const { code } = await transform('./fixtures/attribute-generation-modes.js')
t.snapshot(code)
diff --git a/test/fixtures/attribute-generation-custom-component-class-rewriting.js b/test/fixtures/attribute-generation-custom-component-class-rewriting.js
new file mode 100644
index 00000000..e3744dcf
--- /dev/null
+++ b/test/fixtures/attribute-generation-custom-component-class-rewriting.js
@@ -0,0 +1,11 @@
+import test from 'ava'
+
+export default () => {
+ const Element = 'custom-component'
+ return (
+